From 2bf8788c3036f4e83677f41c225af4fa868d9b7a Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 21 Jan 2025 09:31:06 -0800 Subject: [PATCH] Installer: Fix a bug after ecb5d3c48545a9d3ad41cd34bd77767e93f6ed3b --- .gitignore | 1 + one_click.py | 34 +++++++++++++++++++++++----------- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index ca307c4a..7d1099b6 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ venv .direnv .vs .vscode +.wheels_changed_flag *.bak *.ipynb *.log diff --git a/one_click.py b/one_click.py index 4910f8c7..04a488a0 100644 --- a/one_click.py +++ b/one_click.py @@ -362,13 +362,17 @@ def update_requirements(initial_installation=False, pull=True): requirements_file = base_requirements - # Call git pull - before_pull_whl_lines = [] - if os.path.exists(requirements_file): - with open(requirements_file, 'r') as f: - before_pull_whl_lines = [line for line in f if '.whl' in line] + # Call git pull, while checking if .whl requirements have changed + wheels_changed_from_flag = False + if os.path.exists('.wheels_changed_flag'): + os.remove('.wheels_changed_flag') + wheels_changed_from_flag = True if pull: + if os.path.exists(requirements_file): + with open(requirements_file, 'r') as f: + before_pull_whl_lines = [line for line in f if '.whl' in line] + print_big_message("Updating the local copy of the repository with \"git pull\"") files_to_check = [ @@ -381,16 +385,25 @@ def update_requirements(initial_installation=False, pull=True): run_cmd("git pull --autostash", assert_success=True, environment=True) after_pull_hashes = {file_name: calculate_file_hash(file_name) for file_name in files_to_check} + if os.path.exists(requirements_file): + with open(requirements_file, 'r') as f: + after_pull_whl_lines = [line for line in f if '.whl' in line] + # Check for differences in installation file hashes for file_name in files_to_check: if before_pull_hashes[file_name] != after_pull_hashes[file_name]: print_big_message(f"File '{file_name}' was updated during 'git pull'. Please run the script again.") + + # Check if wheels changed during this pull + wheels_changed = before_pull_whl_lines != after_pull_whl_lines + if wheels_changed: + open('.wheels_changed_flag', 'w').close() + exit(1) - after_pull_whl_lines = [] - if os.path.exists(requirements_file): - with open(requirements_file, 'r') as f: - after_pull_whl_lines = [line for line in f if '.whl' in line] + wheels_changed = wheels_changed_from_flag + if pull: + wheels_changed = wheels_changed or (before_pull_whl_lines != after_pull_whl_lines) if os.environ.get("INSTALL_EXTENSIONS", "").lower() in ("yes", "y", "true", "1", "t", "on"): install_extensions_requirements() @@ -405,8 +418,7 @@ def update_requirements(initial_installation=False, pull=True): # Prepare the requirements file textgen_requirements = open(requirements_file).read().splitlines() - whl_changed = before_pull_whl_lines != after_pull_whl_lines - if not initial_installation and not whl_changed: + if not initial_installation and not wheels_changed: textgen_requirements = [line for line in textgen_requirements if not '.whl' in line] if is_cuda118: