mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-04 14:17:28 +00:00
Organize internals (#6646)
This commit is contained in:
parent
17aa97248f
commit
83c426e96b
6 changed files with 346 additions and 310 deletions
|
|
@ -287,31 +287,62 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||
clear_torch_cache()
|
||||
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', 'xtc_threshold', 'xtc_probability']:
|
||||
for k in [
|
||||
'temperature',
|
||||
'dynatemp_low',
|
||||
'dynatemp_high',
|
||||
'dynatemp_exponent',
|
||||
'smoothing_factor',
|
||||
'smoothing_curve',
|
||||
'min_p',
|
||||
'top_p',
|
||||
'top_k',
|
||||
'typical_p',
|
||||
'xtc_threshold',
|
||||
'xtc_probability',
|
||||
'tfs',
|
||||
'top_a',
|
||||
'dry_multiplier',
|
||||
'dry_allowed_length',
|
||||
'dry_base',
|
||||
'repetition_penalty',
|
||||
'frequency_penalty',
|
||||
'presence_penalty',
|
||||
'encoder_repetition_penalty',
|
||||
'no_repeat_ngram_size',
|
||||
'repetition_penalty_range',
|
||||
'penalty_alpha',
|
||||
'guidance_scale',
|
||||
'mirostat_mode',
|
||||
'mirostat_tau',
|
||||
'mirostat_eta',
|
||||
'max_new_tokens',
|
||||
'do_sample',
|
||||
'dynamic_temperature',
|
||||
'temperature_last',
|
||||
'dry_sequence_breakers',
|
||||
]:
|
||||
if k in state:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
if isinstance(state['sampler_priority'], list) and len(state['sampler_priority']) > 0:
|
||||
generate_params['sampler_priority'] = state['sampler_priority']
|
||||
elif isinstance(state['sampler_priority'], str) and state['sampler_priority'].strip() != '':
|
||||
generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()]
|
||||
|
||||
if state['negative_prompt'] != '':
|
||||
generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
|
||||
|
||||
if state['prompt_lookup_num_tokens'] > 0:
|
||||
generate_params['prompt_lookup_num_tokens'] = state['prompt_lookup_num_tokens']
|
||||
|
||||
if state['static_cache']:
|
||||
generate_params['cache_implementation'] = 'static'
|
||||
|
||||
for k in ['epsilon_cutoff', 'eta_cutoff']:
|
||||
if state[k] > 0:
|
||||
generate_params[k] = state[k] * 1e-4
|
||||
|
||||
if state['prompt_lookup_num_tokens'] > 0:
|
||||
generate_params['prompt_lookup_num_tokens'] = state['prompt_lookup_num_tokens']
|
||||
|
||||
if state['ban_eos_token']:
|
||||
generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
|
||||
|
||||
if state['static_cache']:
|
||||
generate_params['cache_implementation'] = 'static'
|
||||
|
||||
if isinstance(state['sampler_priority'], list) and len(state['sampler_priority']) > 0:
|
||||
generate_params['sampler_priority'] = state['sampler_priority']
|
||||
elif isinstance(state['sampler_priority'], str) and state['sampler_priority'].strip() != '':
|
||||
generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()]
|
||||
|
||||
if state['custom_token_bans']:
|
||||
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
|
||||
if len(to_ban) > 0:
|
||||
|
|
@ -320,6 +351,9 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||
else:
|
||||
generate_params['suppress_tokens'] = to_ban
|
||||
|
||||
if state['negative_prompt'] != '':
|
||||
generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
|
||||
|
||||
generate_params.update({'use_cache': not shared.args.no_cache})
|
||||
if shared.args.deepspeed:
|
||||
generate_params.update({'synced_gpus': True})
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue