translate_hf uses dataloader and accelerate
onadegibert committed Sep 20, 2024
1 parent 9f60d34 commit 7d0ae7b
Showing 1 changed file with 80 additions and 44 deletions.
124 changes: 80 additions & 44 deletions pipeline/translate/translate_hf.py
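
In short: instead of slicing the raw text into manual batches and pushing tensors to a single device, the script now tokenizes the whole input up front, wraps the tensors in a torch Dataset, serves them through a DataLoader, and lets Accelerate handle device placement and batch sharding. A condensed sketch of that pattern, assuming a seq2seq model; the model name below is illustrative and not part of this repository:

    import torch
    from accelerate import Accelerator
    from torch.utils.data import DataLoader, Dataset
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    class TensorDictDataset(Dataset):
        """Index into a dict of equal-length tensors (input_ids, attention_mask)."""
        def __init__(self, encodings):
            self.encodings = encodings

        def __len__(self):
            return len(self.encodings["input_ids"])

        def __getitem__(self, idx):
            return {key: val[idx] for key, val in self.encodings.items()}

    accelerator = Accelerator()

    # Illustrative model; the real script resolves class and name from its args.
    name = "Helsinki-NLP/opus-mt-en-fr"
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = accelerator.prepare(AutoModelForSeq2SeqLM.from_pretrained(name))

    # Tokenize once, then let the DataLoader serve fixed-size batches.
    enc = tokenizer(["Hello world.", "How are you?"], return_tensors="pt", padding=True)
    loader = accelerator.prepare(DataLoader(TensorDictDataset(enc), batch_size=2))

    model.eval()
    with torch.no_grad():
        for batch in loader:
            out = model.generate(**batch, num_beams=4, num_return_sequences=4)
            print(tokenizer.batch_decode(out, skip_special_tokens=True))

The actual script additionally resolves the model class dynamically via importlib, strips the prompt from each decoded hypothesis, and writes beam outputs as "index ||| translation" lines. The full diff follows.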
@@ -3,6 +3,12 @@
 import torch
 import time
 import ast
+from accelerate import Accelerator, DataLoaderConfiguration
+from torch.utils.data import DataLoader, Dataset
+from transformers import AutoTokenizer
+import importlib
 
+torch.cuda.empty_cache() # Clear unused memory
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -25,9 +31,29 @@ def parse_args():
     parser.add_argument('config', type=str, help="Specific configuration for decoding")
     return parser.parse_args()
 
+def convert_simple_dict(d):
+    """Convert numeric strings to integers or floats in a flat dictionary."""
+    return {key: ast.literal_eval(value) if isinstance(value, str) and value.replace('.', '', 1).isdigit() else value for key, value in d.items()}
+
+class TokenizedDataset(Dataset):
+    def __init__(self, tokenized_inputs):
+        self.tokenized_inputs = tokenized_inputs
+
+    def __len__(self):
+        return len(self.tokenized_inputs['input_ids'])
+
+    def __getitem__(self, idx):
+        return {key: val[idx] for key, val in self.tokenized_inputs.items()}
+
 def main():
     #os.environ['HF_HOME'] = args.modeldir
     args = parse_args()
     os.environ['HF_HOME'] = args.modeldir
+
+    # Create a DataLoaderConfiguration object
+    dataloader_config = DataLoaderConfiguration(split_batches=True)
+
+    # Pass the config to the Accelerator
+    accelerator = Accelerator(device_placement=True, dataloader_config=dataloader_config)
 
     print(f"Translating {args.filein} from {args.src} to {args.trg} with {args.modelname}...")
 
@@ -46,87 +72,97 @@ def main():
     # Get the class from the module
     model_class = getattr(module, class_name)
 
-    model = model_class.from_pretrained(model_name, trust_remote_code=True).to(device)
-
+    model = model_class.from_pretrained(model_name, trust_remote_code=True, device_map='auto')
+    model = accelerator.prepare(model) # Prepare model for distributed inference
 
     # Mapping target languages
     src_lang = lang_tags.get(args.src, None)
     tgt_lang = lang_tags.get(args.trg, None)
 
     if args.langinfo in ["True","true","1"]:
-
-        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, src_lang=src_lang, tgt_lang=trg_lang)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, src_lang=src_lang, tgt_lang=tgt_lang, use_fast=True)
     else:
-        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True )
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
 
     num_return_sequences=8
 
     if args.config == "default":
         config=dict()
     else:
-        config=ast.literal_eval(args.config)
+        config=convert_simple_dict(ast.literal_eval(args.config))
 
-    print("Starting translations...")
+    print("Tokenizing...")
 
     # Read the input text
     with open(args.filein, 'r', encoding='utf-8') as infile:
         text = infile.readlines()
 
-    # Prepare for batch processing
+    # Format sentences with prompt
+    formatted_text = [prompt.format(src_lang=src_lang, tgt_lang=tgt_lang, source=t) for t in text]
+
+    # Tokenize all the inputs at once
+    tokenized_inputs = tokenizer(formatted_text, return_tensors='pt', padding=True).to(accelerator.device)
+
+    # Prepare dataset and dataloader
+    dataset = TokenizedDataset(tokenized_inputs)
     batch_size = 32
+    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
 
-    # Open the output file in append mode
-    with open(args.fileout, 'a', encoding='utf-8') as outfile:
-        start_time = time.time() # Start time
-        # Perform the translation with progress print statements
-        for i in range(0, len(text), batch_size):
-            batch = text[i:i+batch_size]
-            input_texts=[prompt.format(src_lang=src_lang, tgt_lang=tgt_lang, source=input_text) for input_text in batch]
-            print("Sample source sentence after prompt formatting:\n", input_texts[0])
-            inputs=tokenizer(input_texts, return_tensors="pt",padding=True).to(device)
+    print("Starting translations...")
+
+    # Accumulate multiple sentences in memory and write them to the file in larger batches
+    buffer_size = 1000000
+    buffer = []
+
+    with open(args.fileout, 'w', encoding='utf-8') as outfile:
+        start_time = time.time()
+        sentence_counter = 0
+
+        for batch in accelerator.prepare(dataloader):
 
             # Generate output
             translated_batch = model.generate(
-                **inputs,
+                **batch,
                 num_return_sequences=num_return_sequences,
                 num_beams=num_return_sequences,
                 **config,
             )
 
             # Decode the output
             translated_batch = tokenizer.batch_decode(translated_batch, skip_special_tokens=True)
 
-            # Write each translated sentence to the output file incrementally
-            i = 0 # Initialize 'i' outside the loop
-            sentence_counter = 0 # Counter to track every 8 sentences
-
-            for sentence in translated_batch:
-                # Remove prompt before writing out
-                if prompt != "{source}":
-                    print("source text:",batch[i])
-                    print("translation:",sentence)
-                    curr_prompt=prompt.format(src_lang=src_lang, tgt_lang=tgt_lang, source=batch[i])
-                    print("prompt:",curr_prompt)
-                    sentence=sentence.replace(curr_prompt,"")
-                    print("fixed translation:",sentence)
-
-                outfile.write(f"{i} ||| {sentence}\n")
-                sentence_counter += 1
-
-                # Increment 'i' every 8 sentences
-                if sentence_counter % num_return_sequences == 0:
-                    i += 1
-
+            # Write each translated sentence to the buffer
+            for i, sentence in enumerate(translated_batch):
+                # Remove prompt before writing out; sentence_counter tracks the
+                # source sentence this hypothesis belongs to
+                curr_prompt = prompt.format(src_lang=src_lang, tgt_lang=tgt_lang, source=text[sentence_counter])
+                sentence = sentence.replace(curr_prompt, "")
+
+                # Add to buffer
+                buffer.append(f"{sentence_counter} ||| {sentence}\n")
+
+                # Increment sentence counter every num_return_sequences sentences
+                if (i + 1) % num_return_sequences == 0:
+                    sentence_counter += 1
+
+                # When buffer is full, write it to file and clear the buffer
+                if len(buffer) >= buffer_size:
+                    outfile.writelines(buffer) # Write buffer to file
+                    buffer = [] # Clear the buffer
+
                 # Print progress every 50 sentences
                 if sentence_counter % 50 == 0:
                     print(f"Translated {sentence_counter} sentences...")
 
-        end_time = time.time() # End time
+        # If there are any remaining sentences in the buffer, flush them to the file
+        if buffer:
+            outfile.writelines(buffer)
+
+        end_time = time.time()
         total_time = end_time - start_time
         translations_per_second = len(text) / total_time if total_time > 0 else float('inf')
 
         # Final progress print
-        print(f"Translation complete. Translating {len(text)} sentences took {total_time} seconds.")
+        print(f"Translation complete. Translating {len(text)} sentences took {total_time:.2f} seconds.")
         print(f"{translations_per_second:.2f} translations/second")
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
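
A note on the new config handling: args.config arrives as a string holding a flat dict literal whose values are typically themselves strings, and convert_simple_dict turns the numeric ones into real numbers before they reach model.generate. A minimal usage example; the dict literal is illustrative, not from the pipeline:

    import ast

    def convert_simple_dict(d):
        """Convert numeric strings to integers or floats in a flat dictionary."""
        return {key: ast.literal_eval(value) if isinstance(value, str) and value.replace('.', '', 1).isdigit() else value for key, value in d.items()}

    raw = ast.literal_eval("{'num_beams': '8', 'temperature': '0.6'}")  # what args.config parses to
    print(convert_simple_dict(raw))  # {'num_beams': 8, 'temperature': 0.6}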
