namespace-Pt commited on
Commit
3189372
1 Parent(s): 1b8bc55

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -1
README.md CHANGED
@@ -62,12 +62,13 @@ with torch.no_grad():
62
  with open("data/toy/narrativeqa.json", encoding="utf-8") as f:
63
  example = json.load(f)
64
  inputs = tokenizer(example["context"], return_tensors="pt").to("cuda")
65
- outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20)[:, inputs["input_ids"].shape[1]:]
66
  print("*"*20)
67
  print(f"Input Length: {inputs['input_ids'].shape[1]}")
68
  print(f"Answer: {example['answer']}")
69
  print(f"Prediction: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
70
  ```
 
71
 
72
  ## Training
73
  *coming soon*
 
62
  with open("data/toy/narrativeqa.json", encoding="utf-8") as f:
63
  example = json.load(f)
64
  inputs = tokenizer(example["context"], return_tensors="pt").to("cuda")
65
+ outputs = model.generate(**inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=20)[:, inputs["input_ids"].shape[1]:]
66
  print("*"*20)
67
  print(f"Input Length: {inputs['input_ids'].shape[1]}")
68
  print(f"Answer: {example['answer']}")
69
  print(f"Prediction: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
70
  ```
71
+ **NOTE**: It's okay to see warnings like `This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (4096). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.` Just ignore it.
72
 
73
  ## Training
74
  *coming soon*