namespace-Pt committed
Commit 3189372
Parent(s): 1b8bc55
Update README.md
README.md CHANGED
@@ -62,12 +62,13 @@ with torch.no_grad():
     with open("data/toy/narrativeqa.json", encoding="utf-8") as f:
         example = json.load(f)
     inputs = tokenizer(example["context"], return_tensors="pt").to("cuda")
-    outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20)[:, inputs["input_ids"].shape[1]:]
+    outputs = model.generate(**inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=20)[:, inputs["input_ids"].shape[1]:]
     print("*"*20)
     print(f"Input Length: {inputs['input_ids'].shape[1]}")
     print(f"Answer: {example['answer']}")
     print(f"Prediction: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
 ```
+**NOTE**: It's okay to see warnings like `This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (4096). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.` Just ignore it.
 
 ## Training
 *coming soon*
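
The only code change in this commit is passing `top_p=1` and `temperature=1` alongside `do_sample=False`. Those are the neutral values for the sampling flags, so decoding stays greedy; passing them explicitly presumably keeps `transformers` from warning that sampling parameters are set while sampling is disabled. For reference, a minimal sketch of the updated snippet in context follows. The imports, the checkpoint ID, and the model-loading lines are assumptions, since the hunk begins inside `with torch.no_grad():` and the rest of the README is not shown in this commit.

```python
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical placeholder, not from the diff: substitute the repo's checkpoint.
model_id = "<your-checkpoint>"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to("cuda")

with torch.no_grad():
    # Load a single toy example with a long "context" and a reference "answer".
    with open("data/toy/narrativeqa.json", encoding="utf-8") as f:
        example = json.load(f)
    inputs = tokenizer(example["context"], return_tensors="pt").to("cuda")
    # do_sample=False means greedy decoding; top_p=1 and temperature=1 are the
    # library defaults, passed explicitly (per this commit) so transformers
    # does not warn about sampling flags being set while sampling is off.
    # The slice drops the prompt tokens, keeping only the generated ones.
    outputs = model.generate(
        **inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=20
    )[:, inputs["input_ids"].shape[1]:]
    print("*" * 20)
    print(f"Input Length: {inputs['input_ids'].shape[1]}")
    print(f"Answer: {example['answer']}")
    print(f"Prediction: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
```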