Mirror of https://github.com/oobabooga/text-generation-webui.git
Revert "Same as 7f06aec3a1 but for exllamav3_hf"
This reverts commit deb37b821b.
parent 163d863443
commit c871d9cdbd
@@ -103,12 +103,6 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
         labels = kwargs.get('labels', None)
         past_key_values = kwargs.get('past_key_values', None)

-        # Reset the internal sequence state for standalone calls (logit viewer)
-        # or the very first step of a new generation.
-        if past_key_values is None:
-            self.past_seq = None
-            self.past_seq_negative = None
-
         if len(args) > 0:
             if not shared.args.cfg_cache:
                 logger.error("Please enable the cfg-cache option to use CFG with ExLlamav3_HF.")
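For context: the six removed lines above cleared the backend's cached sequence state whenever a call arrived without past_key_values (for example, a standalone logits-viewer call or the very first step of a new generation). A minimal sketch of that pattern with a hypothetical stand-in class; only past_seq, past_seq_negative and the past_key_values check come from the diff, everything else is illustrative:

class CacheStateSketch:
    """Illustrative stand-in for Exllamav3HF's sequence-state bookkeeping."""

    def __init__(self):
        self.past_seq = None           # tokens already fed through the cache
        self.past_seq_negative = None  # same, for the CFG negative prompt

    def __call__(self, past_key_values=None, **kwargs):
        # Mirrors the removed block: with no past_key_values there is no
        # ongoing generation, so any remembered sequence is dropped.
        if past_key_values is None:
            self.past_seq = None
            self.past_seq_negative = None
        # ... the real method would run the forward pass from here ...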
@@ -125,8 +119,8 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
         ex_cache = self.ex_cache

         seq = input_ids[0].tolist()
-        if is_negative and past_key_values is not None and isinstance(past_key_values, list):
+        if is_negative and past_key_values is not None:
             seq = past_key_values + seq

         seq_tensor = torch.tensor(seq)
         reset = True
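The only functional change in this hunk is the guard on the negative-prompt (CFG) path: the reverted code additionally required past_key_values to be a plain list before prepending it to the new tokens, while the restored line only checks for None. A minimal sketch of the restored concatenation, assuming past_key_values holds the previously seen negative-prompt token ids as a list (the concrete numbers are illustrative):

import torch

past_key_values = [10, 11, 12]        # cached negative-prompt token ids (assumed list form)
input_ids = torch.tensor([[13, 14]])  # new tokens for this step
is_negative = True

seq = input_ids[0].tolist()
if is_negative and past_key_values is not None:
    seq = past_key_values + seq       # prepend the cached negative context

seq_tensor = torch.tensor(seq)
print(seq_tensor)  # tensor([10, 11, 12, 13, 14])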
@@ -134,50 +128,97 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):

         # Maximum number of tokens to process in a single forward pass
         max_chunk_size = 256

-        if past_seq is not None:
-            min_length = min(past_seq.shape[0], seq_tensor.shape[0])
-            indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
-            if len(indices) == 0 and seq_tensor.shape[0] > past_seq.shape[0]:
-                reset = False
-
-        # Create a single `params` dictionary that will be used and modified
-        # in-place across all `forward` calls within this function.
-        params = {
-            "attn_mode": "flash_attn",
-            "cache": ex_cache,
-            "batch_shape": (1, self.max_tokens),
-            "reconstruct": False,
-            "past_len": 0
-        }
-
         # Make the forward call
         if labels is None:
-            # If it's an efficient continuation, process only the new tokens
-            if not reset:
-                params["past_len"] = past_seq.shape[0]
-                tokens_to_process = seq_tensor[past_seq.shape[0]:]
-            # Otherwise, process the whole sequence from scratch
-            else:
-                tokens_to_process = seq_tensor
-
-            # Process all but the last token of the sequence/sub-sequence
-            if tokens_to_process.shape[0] > 1:
-                prefix_to_process = tokens_to_process[:-1]
-
-                # Process in chunks if the number of tokens is large
-                for i in range(0, prefix_to_process.shape[0], max_chunk_size):
-                    chunk = prefix_to_process[i:i + max_chunk_size]
-                    self.ex_model.forward(input_ids=chunk.view(1, -1), params=params)
-                    params["past_len"] += chunk.shape[0]
-
-            # Process the last token to get logits
-            last_token = tokens_to_process[-1:].view(1, -1)
-            logits = self.ex_model.forward(input_ids=last_token, params=params).to(input_ids.device).float()
+            if past_seq is not None:
+                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
+                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
+                if len(indices) > 0:
+                    longest_prefix = indices[0].item()
+                else:
+                    longest_prefix = min_length
+
+                if longest_prefix > 0:
+                    reset = False
+                    current_len = longest_prefix
+                    remaining_tokens = len(seq_tensor) - longest_prefix - 1
+
+                    if remaining_tokens > 0:
+                        # Process tokens from longest_prefix to second-to-last token
+                        tokens_to_process = seq_tensor[longest_prefix:-1]
+
+                        # Process in chunks if the number of tokens is large
+                        for i in range(0, tokens_to_process.shape[0], max_chunk_size):
+                            chunk = tokens_to_process[i:i + max_chunk_size]
+                            self.ex_model.forward(
+                                input_ids=chunk.view(1, -1),
+                                params={
+                                    "attn_mode": "flash_attn",
+                                    "cache": ex_cache,
+                                    "past_len": longest_prefix + i,
+                                    "batch_shape": (1, self.max_tokens),
+                                    "reconstruct": False  # Force memory-efficient path
+                                }
+                            )
+
+                        current_len = longest_prefix + remaining_tokens
+
+            if reset:
+                if len(seq_tensor) > 1:
+                    # Process all tokens except the last one
+                    tokens_to_process = seq_tensor[:-1]
+
+                    # Process in chunks if the number of tokens is large
+                    current_len = 0
+                    for i in range(0, tokens_to_process.shape[0], max_chunk_size):
+                        chunk = tokens_to_process[i:i + max_chunk_size]
+                        self.ex_model.forward(
+                            input_ids=chunk.view(1, -1),
+                            params={
+                                "attn_mode": "flash_attn",
+                                "cache": ex_cache,
+                                "past_len": current_len,
+                                "batch_shape": (1, self.max_tokens),
+                                "reconstruct": False  # Force memory-efficient path
+                            }
+                        )
+                        current_len += chunk.shape[0]
+                else:
+                    current_len = 0
+
+            # Process the last token and get logits
+            logits = self.ex_model.forward(
+                input_ids=seq_tensor[-1:].view(1, -1),
+                params={
+                    "attn_mode": "flash_attn",
+                    "cache": ex_cache,
+                    "past_len": current_len,
+                    "batch_shape": (1, self.max_tokens),
+                    "reconstruct": False  # Force memory-efficient path
+                }
+            ).to(input_ids.device).float()
         else:
             # When processing with labels, handle as a complete sequence
-            params["attn_mode"] = "flash_attn_nc"
-            logits = self.ex_model.forward(input_ids=seq_tensor.view(1,-1), params=params).float()
+            # Process in chunks if the number of tokens is large
+            tokens_to_process = seq_tensor
+            all_logits = None
+
+            for i in range(0, tokens_to_process.shape[0], max_chunk_size):
+                chunk = tokens_to_process[i:i + max_chunk_size]
+                chunk_logits = self.ex_model.forward(
+                    input_ids=chunk.view(1, -1),
+                    params={
+                        "attn_mode": "flash_attn_nc",  # No caching for training
+                        "reconstruct": False  # Force memory-efficient path
+                    }
+                ).float()
+
+                if all_logits is None:
+                    all_logits = chunk_logits
+                else:
+                    all_logits = torch.cat([all_logits, chunk_logits], dim=1)
+
+            logits = all_logits

         if is_negative:
             self.past_seq_negative = seq_tensor
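Summary of what this hunk reverts to: instead of keeping one params dictionary and mutating its past_len across forward calls (the removed lines), the restored code recomputes the longest prefix shared between the cached sequence (past_seq) and the new one, reuses that much of the cache, re-feeds the remaining tokens in chunks with an explicit past_len per call, and finally runs the last token alone to obtain logits; in the labels branch it likewise chunks the sequence with "flash_attn_nc" and concatenates the per-chunk logits rather than running one full-sequence pass. A self-contained sketch of the cache-reuse bookkeeping, assuming a fake_forward stand-in for self.ex_model.forward and small illustrative numbers:

import torch

def fake_forward(chunk, past_len):
    # Stand-in for self.ex_model.forward(...); only prints the bookkeeping.
    print(f"forward {chunk.tolist()} at past_len={past_len}")

past_seq = torch.tensor([1, 2, 3, 4, 5])       # sequence already in the cache
seq_tensor = torch.tensor([1, 2, 3, 7, 8, 9])  # new full sequence
max_chunk_size = 2                             # the real code uses 256

# 1) Longest prefix shared by the cached and the new sequence.
min_length = min(past_seq.shape[0], seq_tensor.shape[0])
indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
longest_prefix = indices[0].item() if len(indices) > 0 else min_length  # -> 3

# 2) Re-feed everything past the prefix except the last token, in chunks,
#    passing an explicit past_len instead of mutating a shared params dict.
tokens_to_process = seq_tensor[longest_prefix:-1]
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
    chunk = tokens_to_process[i:i + max_chunk_size]
    fake_forward(chunk, past_len=longest_prefix + i)

# 3) The last token runs on its own to produce the logits
#    (past_len = len(seq_tensor) - 1 at this point in the diff).
fake_forward(seq_tensor[-1:], past_len=seq_tensor.shape[0] - 1)

Unlike the removed version, which only continued from the cache when past_seq was an exact prefix of the new sequence (otherwise reset stayed True and everything was reprocessed), the restored version also recovers partial cache reuse when the two sequences diverge midway.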