mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-14 01:23:51 +01:00
Fix ExLlamaV2_HF and ExLlamaV3_HF after ae02ffc605
This commit is contained in:
parent
9c59acf820
commit
b3bf7a885d
|
|
@ -4,10 +4,6 @@ from pathlib import Path
|
|||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Cache,
|
||||
|
|
@ -18,6 +14,15 @@ from exllamav2 import (
|
|||
ExLlamaV2Cache_TP,
|
||||
ExLlamaV2Config
|
||||
)
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import (
|
||||
GenerationConfig,
|
||||
GenerationMixin,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel
|
||||
)
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
|
|
@ -28,7 +33,7 @@ except Exception:
|
|||
traceback.print_exc()
|
||||
|
||||
|
||||
class Exllamav2HF(PreTrainedModel):
|
||||
class Exllamav2HF(PreTrainedModel, GenerationMixin):
|
||||
def __init__(self, config: ExLlamaV2Config):
|
||||
super().__init__(PretrainedConfig())
|
||||
self.ex_config = config
|
||||
|
|
|
|||
|
|
@ -6,7 +6,12 @@ from typing import Any, Dict, Optional, Union
|
|||
import torch
|
||||
from exllamav3 import Cache, Config, Model
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
|
||||
from transformers import (
|
||||
GenerationConfig,
|
||||
GenerationMixin,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel
|
||||
)
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from modules import shared
|
||||
|
|
@ -19,7 +24,7 @@ except Exception:
|
|||
traceback.print_exc()
|
||||
|
||||
|
||||
class Exllamav3HF(PreTrainedModel):
|
||||
class Exllamav3HF(PreTrainedModel, GenerationMixin):
|
||||
def __init__(self, model_dir):
|
||||
super().__init__(PretrainedConfig())
|
||||
self.generation_config = GenerationConfig()
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -34,6 +35,10 @@ def load_model(model_name, loader=None):
|
|||
logger.error('The path to the model does not exist. Exiting.')
|
||||
raise ValueError
|
||||
|
||||
if loader != 'llama.cpp' and 'sampler_hijack' not in sys.modules:
|
||||
from modules import sampler_hijack
|
||||
sampler_hijack.hijack_samplers()
|
||||
|
||||
shared.args.loader = loader
|
||||
output = load_func_map[loader](model_name)
|
||||
if type(output) is tuple:
|
||||
|
|
|
|||
|
|
@ -22,13 +22,11 @@ from transformers import (
|
|||
)
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import sampler_hijack
|
||||
from modules.logging_colors import logger
|
||||
from modules.text_generation import get_reply_from_output_ids
|
||||
from modules.torch_utils import get_device
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
sampler_hijack.hijack_samplers()
|
||||
|
||||
local_rank = None
|
||||
if shared.args.deepspeed:
|
||||
|
|
|
|||
Loading…
Reference in a new issue