diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index 4aa46375..4ba18590 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -35,7 +35,9 @@ except Exception:
 
 class Exllamav2HF(PreTrainedModel, GenerationMixin):
     def __init__(self, config: ExLlamaV2Config):
-        super().__init__(PretrainedConfig())
+        hf_config = PretrainedConfig.from_pretrained(config.model_dir)
+        super().__init__(hf_config)
+
         self.ex_config = config
         self.loras = None
         self.generation_config = GenerationConfig()
diff --git a/modules/exllamav3_hf.py b/modules/exllamav3_hf.py
index d9f4ed57..05b473b7 100644
--- a/modules/exllamav3_hf.py
+++ b/modules/exllamav3_hf.py
@@ -27,11 +27,13 @@ except Exception:
 
 class Exllamav3HF(PreTrainedModel, GenerationMixin):
     def __init__(self, model_dir):
-        super().__init__(PretrainedConfig())
-        self.generation_config = GenerationConfig()
+        hf_config = PretrainedConfig.from_pretrained(model_dir)
+        super().__init__(hf_config)
 
-        config = Config.from_directory(model_dir)
-        self.ex_model = Model.from_config(config)
+        exl3_config = Config.from_directory(model_dir)
+
+        self.generation_config = GenerationConfig()
+        self.ex_model = Model.from_config(exl3_config)
 
         # Calculate the closest multiple of 256 at or above the chosen value
         max_tokens = shared.args.ctx_size