diff --git a/modules/text_generation.py b/modules/text_generation.py
index 585e4d9d..40046eb2 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -215,20 +215,21 @@ def formatted_outputs(reply, model_name):
 
 
 def set_manual_seed(seed):
-    import torch
-    from transformers import is_torch_npu_available, is_torch_xpu_available
-
     seed = int(seed)
     if seed == -1:
         seed = random.randint(1, 2**31)
 
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(seed)
-    elif is_torch_xpu_available():
-        torch.xpu.manual_seed_all(seed)
-    elif is_torch_npu_available():
-        torch.npu.manual_seed_all(seed)
+    if shared.args.loader != 'llama.cpp':
+        import torch
+        from transformers import is_torch_npu_available, is_torch_xpu_available
+
+        torch.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(seed)
+        elif is_torch_xpu_available():
+            torch.xpu.manual_seed_all(seed)
+        elif is_torch_npu_available():
+            torch.npu.manual_seed_all(seed)
 
     return seed
 
@@ -476,11 +477,7 @@ def generate_reply_custom(question, original_question, state, stopping_strings=N
     """
     For models that do not use the transformers library for sampling
    """
-    seed = state['seed']
-    if shared.args.loader != 'llama.cpp':
-        print(shared.args.loader)
-        seed = set_manual_seed(seed)
-
+    seed = set_manual_seed(state['seed'])
     t0 = time.time()
     reply = ''
     try: