JamePeng2023 commited on
Commit
98dd963
1 Parent(s): 6fcf8b4

Add streaming support for text generation

Browse files

- Implemented streaming functionality for real-time text output.
- Added `_decode_stream` method to handle text streaming.
- Updated `chat` method to support streaming mode.
- Adjusted code to process and yield text in chunks for better responsiveness.

This update enhances the user experience by allowing incremental text generation and display.

Files changed (1) hide show
  1. modeling_minicpm.py +64 -8
modeling_minicpm.py CHANGED
@@ -22,12 +22,14 @@ import math
22
  import warnings
23
  from typing import List, Optional, Tuple, Union, Dict
24
 
 
25
  import torch
26
  import torch.nn.functional as F
27
  import torch.utils.checkpoint
28
  from torch import nn
29
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
 
 
31
  from transformers.activations import ACT2FN
32
  from transformers.cache_utils import Cache, DynamicCache
33
  from transformers.modeling_attn_mask_utils import (
@@ -1248,6 +1250,9 @@ class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel):
1248
  self.vocab_size = config.vocab_size
1249
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1250
 
 
 
 
1251
  # Initialize weights and apply final processing
1252
  self.post_init()
1253
 
@@ -1426,11 +1431,52 @@ class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel):
1426
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1427
  )
1428
  return reordered_past
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1429
 
1430
  @torch.inference_mode()
1431
- def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
1432
- max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None,
1433
- **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1434
  if history is None:
1435
  history = []
1436
  if logits_processor:
@@ -1443,12 +1489,22 @@ class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel):
1443
  history.append({"role": role, "content": query})
1444
  history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
1445
  inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
1446
- outputs = self.generate(**inputs, **gen_kwargs)
1447
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1448
- response = tokenizer.decode(outputs)
1449
- history.append({"role": "assistant", "content": response})
1450
- return response, history
1451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1452
 
1453
  @add_start_docstrings(
1454
  """
 
22
  import warnings
23
  from typing import List, Optional, Tuple, Union, Dict
24
 
25
+ from threading import Thread
26
  import torch
27
  import torch.nn.functional as F
28
  import torch.utils.checkpoint
29
  from torch import nn
30
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
 
32
+ from transformers import TextIteratorStreamer
33
  from transformers.activations import ACT2FN
34
  from transformers.cache_utils import Cache, DynamicCache
35
  from transformers.modeling_attn_mask_utils import (
 
1250
  self.vocab_size = config.vocab_size
1251
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1252
 
1253
+ # List of terminator tokens used to indicate the end of a sequence or conversation.
1254
+ self.terminators = ['</s>', '<|im_end|>']
1255
+
1256
  # Initialize weights and apply final processing
1257
  self.post_init()
1258
 
 
1431
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1432
  )
1433
  return reordered_past
1434
+
1435
+ # Internal function to handle streaming of generated text using TextIteratorStreamer.
1436
+ def _decode_stream(self, input_ids, tokenizer, **kwargs):
1437
+ # Convert terminators to token IDs
1438
+ terminators_ids = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
1439
+ # Initialize TextIteratorStreamer for handling streaming output
1440
+ streamer = TextIteratorStreamer(tokenizer=tokenizer,skip_prompt=True, skip_special_tokens=True)
1441
+ # Set up generation parameters, including input IDs, eos token IDs, and streamer
1442
+ generation_kwargs = {
1443
+ 'input_ids': input_ids,
1444
+ 'eos_token_id': terminators_ids,
1445
+ 'streamer': streamer
1446
+ }
1447
+ generation_kwargs.update(kwargs)
1448
+ # Run the generation task in a separate thread to enable streaming output
1449
+ thread = Thread(target=self.generate, kwargs=generation_kwargs)
1450
+ thread.start()
1451
+ # Return the streamer instance for later access to streamed text
1452
+ return streamer
1453
+
1454
 
1455
  @torch.inference_mode()
1456
+ def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", max_length: int = 4096, num_beams=1,
1457
+ do_sample=True, logits_processor=None, stream=False, top_p=0.8, temperature=0.3, **kwargs):
1458
+ """
1459
+ Main function for handling dialogue generation based on the input query and history.
1460
+
1461
+ Parameters:
1462
+ - tokenizer: Tokenizer instance used for encoding and decoding.
1463
+ - query: The user input query string.
1464
+ - history: Dialogue history, a list of dictionaries where each dictionary contains role and content.
1465
+ - role: The current role, default is "user".
1466
+ - max_length: Maximum length of the generated text.
1467
+ - num_beams: Number of beams for beam search.
1468
+ - do_sample: Whether to use sampling for generation.
1469
+ - logits_processor: Function for processing logits (if any).
1470
+ - stream: Whether to use streaming output.
1471
+ - top_p: Nucleus sampling parameter.
1472
+ - temperature: Temperature parameter for generation.
1473
+ - **kwargs: Additional arguments for generation.
1474
+
1475
+ Returns:
1476
+ - If stream is True, returns a generator function to get the generated text incrementally.
1477
+ - If stream is False, returns the complete generated response string.
1478
+ """
1479
+
1480
  if history is None:
1481
  history = []
1482
  if logits_processor:
 
1489
  history.append({"role": role, "content": query})
1490
  history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
1491
  inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
 
 
 
 
 
1492
 
1493
+ if stream:
1494
+ res = self._decode_stream(inputs["input_ids"], tokenizer, **gen_kwargs)
1495
+ def stream_gen():
1496
+ for text in res:
1497
+ # Remove terminators from the text
1498
+ for term in self.terminators:
1499
+ text = text.replace(term, '')
1500
+ yield text
1501
+ return stream_gen()
1502
+
1503
+ else:
1504
+ outputs = self.generate(**inputs, **gen_kwargs)
1505
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1506
+ response = tokenizer.decode(outputs)
1507
+ return response
1508
 
1509
  @add_start_docstrings(
1510
  """