Fix GPT text generation
maanug-nv authored and jaredcasper committed Mar 9, 2023
1 parent ef59b68 commit 3c76018
Showing 3 changed files with 21 additions and 17 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -417,7 +417,7 @@ python tools/checkpoint_util.py \
   --load-dir checkpoints/gpt3_tp4_pp4 \
   --save-dir checkpoints/gpt3_tp2_pp2 \
   --target-tensor-parallel-size 2 \
-  --target-pipeline-paralle-size 2
+  --target-pipeline-parallel-size 2
 
 </pre>

@@ -430,7 +430,7 @@ We have included a simple REST server to use for text generation in `tools/run_t
 Once the server is running you can use `tools/text_generation_cli.py` to query it, it takes one argument which is the host the server is running on.
 
 <pre>
-tools/text_generation_cli.py localhost
+tools/text_generation_cli.py localhost:5000
 </pre>
 
 You can also use CURL or any other tools to query the server directly:
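The README's curl example itself is elided from this diff view. A minimal sketch of such a direct query, assuming the server is up on its default `localhost:5000` (the `PUT` method, `/api` route, and JSON payload shape are taken from `tools/text_generation_cli.py` in this commit; the exact flags are illustrative):

```shell
# Illustrative only; succeeds only while the text-generation server is running.
curl 'http://localhost:5000/api' \
  -X PUT \
  -H 'Content-Type: application/json' \
  -d '{"prompts": ["Hello world"], "tokens_to_generate": 8}' \
  || echo 'server not running'   # keep the sketch harmless when no server is up
```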
4 changes: 3 additions & 1 deletion examples/run_text_generation_server_345M.sh
@@ -10,9 +10,11 @@ CHECKPOINT=<Path to checkpoint (e.g /345m)>
 VOCAB_FILE=<Path to vocab.json (e.g. /gpt2-vocab.json)>
 MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>
 
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
 pip install flask-restful
 
-python -m torch.distributed.run $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
+torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
   --tensor-model-parallel-size 1 \
   --pipeline-model-parallel-size 1 \
   --num-layers 24 \
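The launcher swap above replaces `python -m torch.distributed.run` with its console-script alias `torchrun`; `$DISTRIBUTED_ARGS` is defined in the elided top of the script. A single-node, single-GPU sketch of such a setting (the values here are illustrative assumptions, not the script's verbatim contents):

```shell
# Hypothetical launcher settings for one node with one GPU; adjust
# nproc_per_node and the rendezvous address/port for a real cluster.
DISTRIBUTED_ARGS="--nproc_per_node 1 --nnodes 1 --node_rank 0 \
                  --master_addr localhost --master_port 6000"
echo "torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py ..."
```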
30 changes: 16 additions & 14 deletions tools/text_generation_cli.py
@@ -1,21 +1,23 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import json
 import sys
-import urllib2
-class PutRequest(urllib2.Request):
-    '''class to handling putting with urllib2'''
+import requests
 
-    def get_method(self, *args, **kwargs):
-        return 'PUT'
 
 if __name__ == "__main__":
     url = sys.argv[1]
+    url = 'http://' + url + '/api'
+    headers = {'Content-Type': 'application/json'}
+
     while True:
-        sentence = raw_input("Enter prompt: ")
-        tokens_to_generate = int(input("Enter number of tokens to generate: "))
-        data = json.dumps({"prompts": [sentence], "tokens_to_generate":tokens_to_generate})
-        req = PutRequest(url, data, {'Content-Type': 'application/json'})
-        response = urllib2.urlopen(req)
-        resp_sentences = json.load(response)
-        print("Megatron Response: ")
-        print(resp_sentences["text"][0])
+        sentence = input("Enter prompt: ")
+        tokens_to_generate = int(eval(input("Enter number of tokens to generate: ")))
+
+        data = {"prompts": [sentence], "tokens_to_generate": tokens_to_generate}
+        response = requests.put(url, data=json.dumps(data), headers=headers)
+
+        if response.status_code != 200:
+            print(f"Error {response.status_code}: {response.json()['message']}")
+        else:
+            print("Megatron Response: ")
+            print(response.json()['text'][0])
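The diff above ports the CLI from the Python 2 `urllib2` idiom to `requests`. A runnable sketch of the same request/response cycle, with a throwaway local HTTP handler standing in for the real Megatron server: the PUT method, the `/api` route, and the `{"prompts", "tokens_to_generate"}` / `{"text"}` JSON shapes are taken from this commit's scripts, while `FakeMegatronHandler` and `query` are illustrative names. The stdlib `urllib.request` is used here so the sketch has no third-party dependency; the CLI itself uses `requests`.

```python
import json
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer


class FakeMegatronHandler(BaseHTTPRequestHandler):
    """Stand-in for the Megatron text-generation server's /api endpoint."""

    def do_PUT(self):
        length = int(self.headers["Content-Length"])
        payload = json.loads(self.rfile.read(length))
        # Echo a canned continuation back in the shape the CLI expects.
        body = json.dumps({"text": [payload["prompts"][0] + " <generated>"]}).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, *args):
        pass  # keep the sketch quiet


def query(url, prompt, tokens_to_generate):
    # Same payload and headers as the rewritten text_generation_cli.py,
    # expressed with the stdlib instead of requests.
    data = json.dumps({"prompts": [prompt],
                       "tokens_to_generate": tokens_to_generate}).encode()
    req = urllib.request.Request(url, data=data, method="PUT",
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())["text"][0]


server = HTTPServer(("localhost", 0), FakeMegatronHandler)
threading.Thread(target=server.serve_forever, daemon=True).start()
result = query(f"http://localhost:{server.server_port}/api", "Hello world", 8)
print(result)  # → Hello world <generated>
server.shutdown()
```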