Organize internals (#6646)

oobabooga 2025-01-10 18:04:32 -03:00 committed by GitHub
parent 17aa97248f
commit 83c426e96b
6 changed files with 346 additions and 310 deletions


@@ -287,31 +287,62 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
     clear_torch_cache()
     generate_params = {}
-    for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', 'xtc_threshold', 'xtc_probability']:
+    for k in [
+        'temperature',
+        'dynatemp_low',
+        'dynatemp_high',
+        'dynatemp_exponent',
+        'smoothing_factor',
+        'smoothing_curve',
+        'min_p',
+        'top_p',
+        'top_k',
+        'typical_p',
+        'xtc_threshold',
+        'xtc_probability',
+        'tfs',
+        'top_a',
+        'dry_multiplier',
+        'dry_allowed_length',
+        'dry_base',
+        'repetition_penalty',
+        'frequency_penalty',
+        'presence_penalty',
+        'encoder_repetition_penalty',
+        'no_repeat_ngram_size',
+        'repetition_penalty_range',
+        'penalty_alpha',
+        'guidance_scale',
+        'mirostat_mode',
+        'mirostat_tau',
+        'mirostat_eta',
+        'max_new_tokens',
+        'do_sample',
+        'dynamic_temperature',
+        'temperature_last',
+        'dry_sequence_breakers',
+    ]:
         if k in state:
             generate_params[k] = state[k]
-    if isinstance(state['sampler_priority'], list) and len(state['sampler_priority']) > 0:
-        generate_params['sampler_priority'] = state['sampler_priority']
-    elif isinstance(state['sampler_priority'], str) and state['sampler_priority'].strip() != '':
-        generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()]
-    if state['negative_prompt'] != '':
-        generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
-    if state['prompt_lookup_num_tokens'] > 0:
-        generate_params['prompt_lookup_num_tokens'] = state['prompt_lookup_num_tokens']
-    if state['static_cache']:
-        generate_params['cache_implementation'] = 'static'
     for k in ['epsilon_cutoff', 'eta_cutoff']:
         if state[k] > 0:
             generate_params[k] = state[k] * 1e-4
+    if state['prompt_lookup_num_tokens'] > 0:
+        generate_params['prompt_lookup_num_tokens'] = state['prompt_lookup_num_tokens']
     if state['ban_eos_token']:
         generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
+    if state['static_cache']:
+        generate_params['cache_implementation'] = 'static'
+    if isinstance(state['sampler_priority'], list) and len(state['sampler_priority']) > 0:
+        generate_params['sampler_priority'] = state['sampler_priority']
+    elif isinstance(state['sampler_priority'], str) and state['sampler_priority'].strip() != '':
+        generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()]
     if state['custom_token_bans']:
         to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
         if len(to_ban) > 0:
@@ -320,6 +351,9 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
             else:
                 generate_params['suppress_tokens'] = to_ban
+    if state['negative_prompt'] != '':
+        generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
     generate_params.update({'use_cache': not shared.args.no_cache})
     if shared.args.deepspeed:
         generate_params.update({'synced_gpus': True})
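
For reference, here is a minimal standalone sketch of what the reorganized block assembles after this commit: sampler keys present in the UI state are copied into generate_params, epsilon_cutoff/eta_cutoff are rescaled by 1e-4, and sampler_priority is accepted either as a list or as a comma/newline-separated string. The build_generate_params function and the demo_state dict are illustrative names only, and a plain dict stands in for the Gradio state; shared, the tokenizer, encode(), and the remaining flags from the real module are omitted.

# Illustrative sketch only -- not the module itself. It mirrors the ordering
# used after this commit: sampler keys first, cutoff rescaling, then extras.
def build_generate_params(state):
    generate_params = {}

    # Copy whichever sampler settings are present in the UI state
    # (shortened key list; the real one is in the diff above).
    for k in [
        'temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent',
        'min_p', 'top_p', 'top_k', 'typical_p', 'repetition_penalty',
        'max_new_tokens', 'do_sample',
    ]:
        if k in state:
            generate_params[k] = state[k]

    # These cutoffs are rescaled by 1e-4 before being passed on, as in the diff.
    for k in ['epsilon_cutoff', 'eta_cutoff']:
        if state.get(k, 0) > 0:
            generate_params[k] = state[k] * 1e-4

    # sampler_priority may arrive as a list or as a comma/newline-separated string.
    priority = state.get('sampler_priority', '')
    if isinstance(priority, list) and len(priority) > 0:
        generate_params['sampler_priority'] = priority
    elif isinstance(priority, str) and priority.strip() != '':
        generate_params['sampler_priority'] = [
            x.strip() for x in priority.replace('\n', ',').split(',') if x.strip()
        ]

    return generate_params


if __name__ == '__main__':
    demo_state = {
        'temperature': 0.7,
        'top_p': 0.9,
        'max_new_tokens': 512,
        'do_sample': True,
        'epsilon_cutoff': 3,
        'sampler_priority': 'min_p\ntemperature, top_p',
    }
    print(build_generate_params(demo_state))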