diff --git a/.gitignore b/.gitignore
index 5b1de5ab8..dd5f7af7b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
 __pycache__
 .mypy_cache/
 models/
+checkpoint
+samples
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 000000000..8fc2af403
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,14 @@
+FROM tensorflow/tensorflow:1.15.2-py3-jupyter
+
+# setup environment language
+ENV LANG=C.UTF-8
+
+# copy requirements.txt into image
+COPY requirements.txt requirements.txt
+
+# update and upgrade packages and pip and install python libraries
+RUN apt-get update && apt-get upgrade -y \
+&& apt-get install -y apt-utils \
+&& pip3 install --upgrade pip \
+&& pip3 install -r requirements.txt \
+&& rm requirements.txt
diff --git a/Dockerfile.cpu b/Dockerfile.cpu
deleted file mode 100644
index a02d2b320..000000000
--- a/Dockerfile.cpu
+++ /dev/null
@@ -1,9 +0,0 @@
-FROM tensorflow/tensorflow:1.12.0-py3
-
-ENV LANG=C.UTF-8
-RUN mkdir /gpt-2
-WORKDIR /gpt-2
-ADD . /gpt-2
-RUN pip3 install -r requirements.txt
-RUN python3 download_model.py 117M
-RUN python3 download_model.py 345M
diff --git a/Dockerfile.gpu b/Dockerfile.gpu
deleted file mode 100644
index b3f87db14..000000000
--- a/Dockerfile.gpu
+++ /dev/null
@@ -1,18 +0,0 @@
-FROM tensorflow/tensorflow:1.12.0-gpu-py3
-
-# nvidia-docker 1.0
-LABEL com.nvidia.volumes.needed="nvidia_driver"
-LABEL com.nvidia.cuda.version="${CUDA_VERSION}"
-
-# nvidia-container-runtime
-ENV NVIDIA_VISIBLE_DEVICES=all \
-    NVIDIA_DRIVER_CAPABILITIES=compute,utility \
-    NVIDIA_REQUIRE_CUDA="cuda>=8.0" \
-    LANG=C.UTF-8
-
-RUN mkdir /gpt-2
-WORKDIR /gpt-2
-ADD . /gpt-2
-RUN pip3 install -r requirements.txt
-RUN python3 download_model.py 117M
-RUN python3 download_model.py 345M
diff --git a/README.md b/README.md
index e38f2c1ce..e3aabc22f 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,6 @@
+
+Reference: ["Beginner’s Guide to Retrain GPT-2 (117M) to Generate Custom Text Content"](https://medium.com/@ngwaifoong92/beginners-guide-to-retrain-gpt-2-117m-to-generate-custom-text-content-8bb5363d8b7f)
+
 # gpt-2
 
 Code from the paper ["Language Models are Unsupervised Multitask Learners"](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf).
@@ -30,6 +33,58 @@ See [DEVELOPERS.md](./DEVELOPERS.md)
 
 See [CONTRIBUTORS.md](./CONTRIBUTORS.md)
 
+## Fine tuning on custom datasets
+
+To retrain the GPT-2 117M model on a custom text dataset:
+
+```
+PYTHONPATH=src ./train.py --dataset <file|directory|glob>
+```
+
+If you want to precompute the dataset's encoding for multiple runs, you can instead use:
+
+```
+PYTHONPATH=src ./encode.py <file|directory|glob> /path/to/encoded.npz
+PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz
+```
+
+Make sure `cudnn` is installed. [Some have reported](https://github.com/nshepperd/gpt-2/issues/8) that `train.py` runs without it, but with worse memory usage, and it may OOM.
+
+### Gradient Checkpointing
+
+https://github.com/openai/gradient-checkpointing is included to reduce the memory requirements of the model, and can be enabled by `--memory_saving_gradients`. The checkpoints are currently chosen manually (poorly) by just adding layer 10 to the 'checkpoints' collection in model.py. `--memory_saving_gradients` is enabled by default for training the 345M model.
+
+### Validation loss
+
+Set `--val_every` to a number of steps `N > 0`, and "validation" loss against a fixed sample of the dataset will be calculated every N steps to get a better sense of training progress. N around 200 is suggested.
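
+
+For example, with a hypothetical pre-encoded dataset, validating every 200 steps:
+
+```
+PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz --val_every 200
+```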

+
+You can set `--val_dataset` to choose a separate validation dataset; otherwise it defaults to a sample from the train dataset (so it is not a true held-out validation loss!).
+
+### Optimizer
+
+You can use SGD instead of Adam with `--optimizer sgd`. This also helps conserve memory when training the 345M model. Note: the learning rate needs to be adjusted for SGD, since it lacks Adam's gradient normalization (0.0006 seems to be a good number from some experiments).
+
+### Multi GPU (out of date)
+
+To do distributed training on multiple GPUs or machines using Horovod:
+
+```
+mpirun -np 4 \
+    -H localhost:4 \
+    -bind-to none -map-by slot \
+    -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH \
+    -x PYTHONPATH=src \
+    -mca pml ob1 -mca btl ^openib \
+    /home/jovyan/gpt-2/train-horovod.py --dataset encoded.npz
+```
+
+## GPT-2 samples
+
+| WARNING: Samples are unfiltered and may contain offensive content. |
+| --- |
+
+While we have not yet released GPT-2 itself, you can see some samples from it in the `gpt-2-samples` folder.
+We show unconditional samples with default settings (temperature 1 and no truncation), with temperature 0.7, and with truncation with top_k 40.
+We show conditional samples, with contexts drawn from `WebText`'s test set, with default settings (temperature 1 and no truncation), with temperature 0.7, and with truncation with top_k 40.
+
 ## Citation
 
 Please use the following bibtex entry:
diff --git a/encode.py b/encode.py
new file mode 100755
index 000000000..a9238d310
--- /dev/null
+++ b/encode.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python3
+# Usage:
+# PYTHONPATH=src ./encode.py <file|directory|glob> /path/to/output.npz
+# PYTHONPATH=src ./train.py --dataset /path/to/output.npz
+
+import argparse
+import numpy as np
+
+import encoder
+from load_dataset import load_dataset
+
+parser = argparse.ArgumentParser(
+    description='Pre-encode text files into tokenized training set.',
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--model_name', metavar='MODEL', type=str, default='117M', help='Pretrained model name')
+parser.add_argument('--combine', metavar='CHARS', type=int, default=50000, help='Concatenate files with <|endoftext|> separator into chunks of this minimum size')
+parser.add_argument('--encoding', type=str, default='utf-8', help='Set the encoding for reading and writing files.')
+parser.add_argument('in_text', metavar='PATH', type=str, help='Input file, directory, or glob pattern (utf-8 text).')
+parser.add_argument('out_npz', metavar='OUT.npz', type=str, help='Output file path')
+
+def main():
+    args = parser.parse_args()
+    enc = encoder.get_encoder(args.model_name)
+    print('Reading files')
+    chunks = load_dataset(enc, args.in_text, args.combine, encoding=args.encoding)
+    print('Writing', args.out_npz)
+    np.savez_compressed(args.out_npz, *chunks)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/requirements.txt b/requirements.txt
index 2cc521d5e..b4a3ea703 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,3 +2,4 @@ fire>=0.1.3
 regex==2017.4.5
 requests==2.21.0
 tqdm==4.31.1
+toposort==1.5
diff --git a/src/accumulate.py b/src/accumulate.py
new file mode 100644
index 000000000..c5a4a81e8
--- /dev/null
+++ b/src/accumulate.py
@@ -0,0 +1,36 @@
+import argparse
+import json
+import os
+import numpy as np
+import tensorflow as tf
+import time
+
+
+class AccumulatingOptimizer(object):
+    def __init__(self, opt, var_list):
+        self.opt = opt
+        self.var_list = var_list
+        self.accum_vars = {tv : tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False)
+                           for tv in 
var_list} + self.total_loss = tf.Variable(tf.zeros(shape=[], dtype=tf.float32)) + self.count_loss = tf.Variable(tf.zeros(shape=[], dtype=tf.float32)) + + def reset(self): + updates = [tv.assign(tf.zeros_like(tv)) for tv in self.accum_vars.values()] + updates.append(self.total_loss.assign(tf.zeros(shape=[], dtype=tf.float32))) + updates.append(self.count_loss.assign(tf.zeros(shape=[], dtype=tf.float32))) + with tf.control_dependencies(updates): + return tf.no_op() + + def compute_gradients(self, loss): + grads = self.opt.compute_gradients(loss, self.var_list) + updates = [self.accum_vars[v].assign_add(g) for (g,v) in grads] + updates.append(self.total_loss.assign_add(loss)) + updates.append(self.count_loss.assign_add(1.0)) + with tf.control_dependencies(updates): + return tf.no_op() + + def apply_gradients(self): + grads = [(g,v) for (v,g) in self.accum_vars.items()] + with tf.control_dependencies([self.opt.apply_gradients(grads)]): + return self.total_loss / self.count_loss diff --git a/src/generate_unconditional_samples.py b/src/generate_unconditional_samples.py index 87e212972..d9e21319a 100755 --- a/src/generate_unconditional_samples.py +++ b/src/generate_unconditional_samples.py @@ -16,6 +16,7 @@ def sample_model( length=None, temperature=1, top_k=0, + top_p=0.0 ): """ Run the sample_model @@ -35,6 +36,8 @@ def sample_model( considered for each step (token), resulting in deterministic completions, while 40 means 40 words are considered at each step. 0 (default) is a special setting meaning no restrictions. 40 generally is a good value. + :top_p=0.0 : Float value controlling diversity. Implements nucleus sampling, + overriding top_k if set to a value > 0. A good setting is 0.9. """ enc = encoder.get_encoder(model_name) hparams = model.default_hparams() @@ -54,7 +57,7 @@ def sample_model( hparams=hparams, length=length, start_token=enc.encoder['<|endoftext|>'], batch_size=batch_size, - temperature=temperature, top_k=top_k + temperature=temperature, top_k=top_k, top_p=top_p )[:, 1:] saver = tf.train.Saver() @@ -72,4 +75,3 @@ def sample_model( if __name__ == '__main__': fire.Fire(sample_model) - diff --git a/src/interactive_conditional_samples.py b/src/interactive_conditional_samples.py index 166171aaf..c1650bbea 100755 --- a/src/interactive_conditional_samples.py +++ b/src/interactive_conditional_samples.py @@ -16,6 +16,7 @@ def interact_model( length=None, temperature=1, top_k=0, + top_p=0.0 ): """ Interactively run the model @@ -34,6 +35,8 @@ def interact_model( considered for each step (token), resulting in deterministic completions, while 40 means 40 words are considered at each step. 0 (default) is a special setting meaning no restrictions. 40 generally is a good value. + :top_p=0.0 : Float value controlling diversity. Implements nucleus sampling, + overriding top_k if set to a value > 0. A good setting is 0.9. 
""" if batch_size is None: batch_size = 1 @@ -57,7 +60,7 @@ def interact_model( hparams=hparams, length=length, context=context, batch_size=batch_size, - temperature=temperature, top_k=top_k + temperature=temperature, top_k=top_k, top_p=top_p ) saver = tf.train.Saver() @@ -84,4 +87,3 @@ def interact_model( if __name__ == '__main__': fire.Fire(interact_model) - diff --git a/src/load_dataset.py b/src/load_dataset.py new file mode 100644 index 000000000..499c3a9d0 --- /dev/null +++ b/src/load_dataset.py @@ -0,0 +1,83 @@ +import glob +import numpy as np +import os +import tensorflow as tf +import tqdm + + +def load_dataset(enc, path, combine, encoding=None): + paths = [] + if os.path.isfile(path): + # Simple file + paths.append(path) + elif os.path.isdir(path): + # Directory + for (dirpath, _, fnames) in os.walk(path): + for fname in fnames: + paths.append(os.path.join(dirpath, fname)) + else: + # Assume glob + paths = glob.glob(path) + + token_chunks = [] + raw_text = '' + for path in tqdm.tqdm(paths): + if path.endswith('.npz'): + # Pre-encoded + with np.load(path) as npz: + for item in npz.files: + token_chunks.append(npz[item]) + else: + # Plain text + with open(path, 'r', encoding=encoding) as fp: + raw_text += fp.read() + if len(raw_text) >= combine: + tokens = np.stack(enc.encode(raw_text)) + token_chunks.append(tokens) + raw_text = '' + else: + raw_text += '<|endoftext|>' + if raw_text: + tokens = np.stack(enc.encode(raw_text)) + token_chunks.append(tokens) + return token_chunks + + +def binary_search(f, lo, hi): + if f(lo) or not f(hi): + return None + while hi > lo + 1: + mid = (lo + hi) // 2 + if f(mid): + hi = mid + else: + lo = mid + return hi + + +class Sampler(object): + """Fairly samples a slice from a set of variable sized chunks. + + 'Fairly' means that the distribution is the same as sampling from one concatenated chunk, + but without crossing chunk boundaries.""" + + def __init__(self, chunks, seed=None): + self.chunks = chunks + self.total_size = sum(chunk.shape[0] for chunk in chunks) + self.boundaries = [0] + for i in range(len(chunks)): + self.boundaries.append(self.boundaries[-1] + chunks[i].shape[0]) + self.rs = np.random.RandomState(seed=seed) + + def sample(self, length): + assert length < self.total_size // len( + self.chunks + ), "Dataset files are too small to sample {} tokens at a time".format( + length) + while True: + index = self.rs.randint(0, self.total_size - length - 1) + i = binary_search(lambda j: self.boundaries[j] > index, 0, + len(self.boundaries) - 1) - 1 + if self.boundaries[i + 1] > index + length: + within_chunk = index - self.boundaries[i] + return self.chunks[i][within_chunk:within_chunk + length] diff --git a/src/memory_saving_gradients.py b/src/memory_saving_gradients.py new file mode 100644 index 000000000..659691f49 --- /dev/null +++ b/src/memory_saving_gradients.py @@ -0,0 +1,387 @@ +from toposort import toposort +import contextlib +import numpy as np +import tensorflow as tf +import tensorflow.contrib.graph_editor as ge +import time +import sys +sys.setrecursionlimit(10000) +# refers back to current module if we decide to split helpers out +util = sys.modules[__name__] + +# getting rid of "WARNING:tensorflow:VARIABLES collection name is deprecated" +setattr(tf.GraphKeys, "VARIABLES", "variables") + +# save original gradients since tf.gradient could be monkey-patched to point +# to our version +from tensorflow.python.ops import gradients as tf_gradients_lib +tf_gradients = tf_gradients_lib.gradients + +MIN_CHECKPOINT_NODE_SIZE=1024 # use 
lower value during testing + +# specific versions we can use to do process-wide replacement of tf.gradients +def gradients_speed(ys, xs, grad_ys=None, **kwargs): + return gradients(ys, xs, grad_ys, checkpoints='speed', **kwargs) + +def gradients_memory(ys, xs, grad_ys=None, **kwargs): + return gradients(ys, xs, grad_ys, checkpoints='memory', **kwargs) + +def gradients_collection(ys, xs, grad_ys=None, **kwargs): + return gradients(ys, xs, grad_ys, checkpoints='collection', **kwargs) + +def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs): + ''' + Authors: Tim Salimans & Yaroslav Bulatov + + memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost" + by Chen et al. 2016 (https://arxiv.org/abs/1604.06174) + + ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients + (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients) + + 'checkpoints' can either be + - a list consisting of tensors from the forward pass of the neural net + that we should re-use when calculating the gradients in the backward pass + all other tensors that do not appear in this list will be re-computed + - a string specifying how this list should be determined. currently we support + - 'speed': checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive, + so checkpointing them maximizes the running speed + (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory) + - 'memory': try to minimize the memory usage + (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint) + - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint + ''' + + # print("Calling memsaving gradients with", checkpoints) + if not isinstance(ys,list): + ys = [ys] + if not isinstance(xs,list): + xs = [xs] + + bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], + inclusive=True) + + debug_print("bwd_ops: %s", bwd_ops) + + # forward ops are all ops that are candidates for recomputation + fwd_ops = ge.get_forward_walk_ops([x.op for x in xs], + inclusive=True, + within_ops=bwd_ops) + debug_print("fwd_ops: %s", fwd_ops) + + # exclude ops with no inputs + fwd_ops = [op for op in fwd_ops if op.inputs] + + # don't recompute xs, remove variables + xs_ops = _to_ops(xs) + fwd_ops = [op for op in fwd_ops if not op in xs_ops] + fwd_ops = [op for op in fwd_ops if not '/assign' in op.name] + fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name] + fwd_ops = [op for op in fwd_ops if not '/read' in op.name] + ts_all = ge.filter_ts(fwd_ops, True) # get the tensors + ts_all = [t for t in ts_all if '/read' not in t.name] + ts_all = set(ts_all) - set(xs) - set(ys) + + # construct list of tensors to checkpoint during forward pass, if not + # given as input + if type(checkpoints) is not list: + if checkpoints == 'collection': + checkpoints = tf.get_collection('checkpoints') + + elif checkpoints == 'speed': + # checkpoint all expensive ops to maximize running speed + checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul') + + elif checkpoints == 'memory': + + # remove very small tensors and some weird ops + def fixdims(t): # tf.Dimension values are not compatible with int, convert manually + try: + return [int(e if e.value is not None else 64) for e in t] + except: + return [0] # unknown shape + ts_all = [t for t in ts_all if np.prod(fixdims(t.shape)) > 
MIN_CHECKPOINT_NODE_SIZE] + ts_all = [t for t in ts_all if 'L2Loss' not in t.name] + ts_all = [t for t in ts_all if 'entropy' not in t.name] + ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name] + ts_all = [t for t in ts_all if 'Switch' not in t.name] + ts_all = [t for t in ts_all if 'dropout' not in t.name] + # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16 + ts_all = [t for t in ts_all if 'Cast' not in t.name] + + # filter out all tensors that are inputs of the backward graph + with util.capture_ops() as bwd_ops: + tf_gradients(ys, xs, grad_ys, **kwargs) + + bwd_inputs = [t for op in bwd_ops for t in op.inputs] + # list of tensors in forward graph that is in input to bwd graph + ts_filtered = list(set(bwd_inputs).intersection(ts_all)) + debug_print("Using tensors %s", ts_filtered) + + # try two slightly different ways of getting bottlenecks tensors + # to checkpoint + for ts in [ts_filtered, ts_all]: + + # get all bottlenecks in the graph + bottleneck_ts = [] + for t in ts: + b = set(ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops)) + f = set(ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops)) + # check that there are not shortcuts + b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all) + f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all) + if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all): + bottleneck_ts.append(t) # we have a bottleneck! + else: + debug_print("Rejected bottleneck candidate and ops %s", [t] + list(set(ts_all) - set(b_inp) - set(f_inp))) + + # success? or try again without filtering? + if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)): # yes, enough bottlenecks found! + break + + if not bottleneck_ts: + raise Exception('unable to find bottleneck tensors! 
please provide checkpoint nodes manually, or use checkpoints="speed".') + + # sort the bottlenecks + bottlenecks_sorted_lists = tf_toposort(bottleneck_ts, within_ops=fwd_ops) + sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts] + + # save an approximately optimal number ~ sqrt(N) + N = len(ts_filtered) + if len(bottleneck_ts) <= np.ceil(np.sqrt(N)): + checkpoints = sorted_bottlenecks + else: + step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N))) + checkpoints = sorted_bottlenecks[step::step] + + else: + raise Exception('%s is unsupported input for "checkpoints"' % (checkpoints,)) + + checkpoints = list(set(checkpoints).intersection(ts_all)) + + # at this point automatic selection happened and checkpoints is list of nodes + assert isinstance(checkpoints, list) + + debug_print("Checkpoint nodes used: %s", checkpoints) + # better error handling of special cases + # xs are already handled as checkpoint nodes, so no need to include them + xs_intersect_checkpoints = set(xs).intersection(set(checkpoints)) + if xs_intersect_checkpoints: + debug_print("Warning, some input nodes are also checkpoint nodes: %s", + xs_intersect_checkpoints) + ys_intersect_checkpoints = set(ys).intersection(set(checkpoints)) + debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints, + ys_intersect_checkpoints) + # saving an output node (ys) gives no benefit in memory while creating + # new edge cases, exclude them + if ys_intersect_checkpoints: + debug_print("Warning, some output nodes are also checkpoints nodes: %s", + format_ops(ys_intersect_checkpoints)) + + # remove initial and terminal nodes from checkpoints list if present + checkpoints = list(set(checkpoints) - set(ys) - set(xs)) + + # check that we have some nodes to checkpoint + # if not checkpoints: + # raise Exception('no checkpoints nodes found or given as input! 
') + + # disconnect dependencies between checkpointed tensors + checkpoints_disconnected = {} + for x in checkpoints: + if x.op and x.op.name is not None: + grad_node = tf.stop_gradient(x, name=x.op.name+"_sg") + else: + grad_node = tf.stop_gradient(x) + checkpoints_disconnected[x] = grad_node + + # partial derivatives to the checkpointed tensors and xs + ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys], + stop_at_ts=checkpoints, within_ops=fwd_ops) + debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s", + len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints) + debug_print("ops_to_copy = %s", ops_to_copy) + debug_print("Processing list %s", ys) + copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {}) + for origin_op, op in info._transformed_ops.items(): + op._set_device(origin_op.node_def.device) + copied_ops = info._transformed_ops.values() + debug_print("Copied %s to %s", ops_to_copy, copied_ops) + ge.reroute_ts(checkpoints_disconnected.values(), checkpoints_disconnected.keys(), can_modify=copied_ops) + debug_print("Rewired %s in place of %s restricted to %s", + checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops) + + # get gradients with respect to current boundary + original x's + copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys] + boundary = list(checkpoints_disconnected.values()) + dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs) + debug_print("Got gradients %s", dv) + debug_print("for %s", copied_ys) + debug_print("with respect to %s", boundary+xs) + + inputs_to_do_before = [y.op for y in ys] + if grad_ys is not None: + inputs_to_do_before += grad_ys + wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] + my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) + + # partial derivatives to the checkpointed nodes + # dictionary of "node: backprop" for nodes in the boundary + d_checkpoints = {r: dr for r,dr in zip(checkpoints_disconnected.keys(), + dv[:len(checkpoints_disconnected)])} + # partial derivatives to xs (usually the params of the neural net) + d_xs = dv[len(checkpoints_disconnected):] + + # incorporate derivatives flowing through the checkpointed nodes + checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops) + for ts in checkpoints_sorted_lists[::-1]: + debug_print("Processing list %s", ts) + checkpoints_other = [r for r in checkpoints if r not in ts] + checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other] + + # copy part of the graph below current checkpoint node, stopping at + # other checkpoints nodes + ops_to_copy = fast_backward_ops(within_ops=fwd_ops, seed_ops=[r.op for r in ts], stop_at_ts=checkpoints_other) + debug_print("Found %s ops to copy within %s, seed %s, stop_at %s", + len(ops_to_copy), fwd_ops, [r.op for r in ts], + checkpoints_other) + debug_print("ops_to_copy = %s", ops_to_copy) + if not ops_to_copy: # we're done! 
+ break + copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {}) + for origin_op, op in info._transformed_ops.items(): + op._set_device(origin_op.node_def.device) + copied_ops = info._transformed_ops.values() + debug_print("Copied %s to %s", ops_to_copy, copied_ops) + ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other, can_modify=copied_ops) + debug_print("Rewired %s in place of %s restricted to %s", + checkpoints_disconnected_other, checkpoints_other, copied_ops) + + # gradient flowing through the checkpointed node + boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts] + substitute_backprops = [d_checkpoints[r] for r in ts] + dv = tf_gradients(boundary, + checkpoints_disconnected_other+xs, + grad_ys=substitute_backprops, **kwargs) + debug_print("Got gradients %s", dv) + debug_print("for %s", boundary) + debug_print("with respect to %s", checkpoints_disconnected_other+xs) + debug_print("with boundary backprop substitutions %s", substitute_backprops) + + inputs_to_do_before = [d_checkpoints[r].op for r in ts] + wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] + my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) + + # partial derivatives to the checkpointed nodes + for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]): + if dr is not None: + if d_checkpoints[r] is None: + d_checkpoints[r] = dr + else: + d_checkpoints[r] += dr + def _unsparsify(x): + if not isinstance(x, tf.IndexedSlices): + return x + assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape" + indices = x.indices + while indices.shape.ndims < x.values.shape.ndims: + indices = tf.expand_dims(indices, -1) + return tf.scatter_nd(indices, x.values, x.dense_shape) + + # partial derivatives to xs (usually the params of the neural net) + d_xs_new = dv[len(checkpoints_other):] + for j in range(len(xs)): + if d_xs_new[j] is not None: + if d_xs[j] is None: + d_xs[j] = _unsparsify(d_xs_new[j]) + else: + d_xs[j] += _unsparsify(d_xs_new[j]) + + + return d_xs + +def tf_toposort(ts, within_ops=None): + all_ops = ge.get_forward_walk_ops([x.op for x in ts], within_ops=within_ops) + + deps = {} + for op in all_ops: + for o in op.outputs: + deps[o] = set(op.inputs) + sorted_ts = toposort(deps) + + # only keep the tensors from our original list + ts_sorted_lists = [] + for l in sorted_ts: + keep = list(set(l).intersection(ts)) + if keep: + ts_sorted_lists.append(keep) + + return ts_sorted_lists + +def fast_backward_ops(within_ops, seed_ops, stop_at_ts): + bwd_ops = set(ge.get_backward_walk_ops(seed_ops, stop_at_ts=stop_at_ts)) + ops = bwd_ops.intersection(within_ops).difference([t.op for t in stop_at_ts]) + return list(ops) + +@contextlib.contextmanager +def capture_ops(): + """Decorator to capture ops created in the block. + with capture_ops() as ops: + # create some ops + print(ops) # => prints ops created. 
+ """ + + micros = int(time.time()*10**6) + scope_name = str(micros) + op_list = [] + with tf.name_scope(scope_name): + yield op_list + + g = tf.get_default_graph() + op_list.extend(ge.select_ops(scope_name+"/.*", graph=g)) + +def _to_op(tensor_or_op): + if hasattr(tensor_or_op, "op"): + return tensor_or_op.op + return tensor_or_op + +def _to_ops(iterable): + if not _is_iterable(iterable): + return iterable + return [_to_op(i) for i in iterable] + +def _is_iterable(o): + try: + _ = iter(o) + except Exception: + return False + return True + +DEBUG_LOGGING=False +def debug_print(s, *args): + """Like logger.log, but also replaces all TensorFlow ops/tensors with their + names. Sensitive to value of DEBUG_LOGGING, see enable_debug/disable_debug + + Usage: + debug_print("see tensors %s for %s", tensorlist, [1,2,3]) + """ + + if DEBUG_LOGGING: + formatted_args = [format_ops(arg) for arg in args] + print("DEBUG "+s % tuple(formatted_args)) + +def format_ops(ops, sort_outputs=True): + """Helper method for printing ops. Converts Tensor/Operation op to op.name, + rest to str(op).""" + + if hasattr(ops, '__iter__') and not isinstance(ops, str): + l = [(op.name if hasattr(op, "name") else str(op)) for op in ops] + if sort_outputs: + return sorted(l) + return l + else: + return ops.name if hasattr(ops, "name") else str(ops) + +def my_add_control_inputs(wait_to_do_ops, inputs_to_do_before): + for op in wait_to_do_ops: + ci = [i for i in inputs_to_do_before if op.control_inputs is None or i not in op.control_inputs] + ge.add_control_inputs(op, ci) diff --git a/src/model.py b/src/model.py index 230b83cc2..4e942d873 100644 --- a/src/model.py +++ b/src/model.py @@ -144,7 +144,7 @@ def positions_for(tokens, past_length): return expand_tile(past_length + tf.range(nsteps), batch_size) -def model(hparams, X, past=None, scope='model', reuse=False): +def model(hparams, X, past=None, scope='model', reuse=tf.AUTO_REUSE): with tf.variable_scope(scope, reuse=reuse): results = {} batch, sequence = shape_list(X) @@ -162,6 +162,8 @@ def model(hparams, X, past=None, scope='model', reuse=False): assert len(pasts) == hparams.n_layer for layer, past in enumerate(pasts): h, present = block(h, 'h%d' % layer, past=past, hparams=hparams) + if layer == 10: + tf.add_to_collection('checkpoints', h) presents.append(present) results['present'] = tf.stack(presents, axis=1) h = norm(h, 'ln_f') diff --git a/src/sample.py b/src/sample.py index c309ef0da..40d475e82 100644 --- a/src/sample.py +++ b/src/sample.py @@ -22,7 +22,21 @@ def _top_k(): ) -def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0): +def top_p_logits(logits, p): + with tf.variable_scope('top_p_logits'): + logits_sort = tf.sort(logits, direction='DESCENDING') + probs_sort = tf.nn.softmax(logits_sort) + probs_sums = tf.cumsum(probs_sort, axis=1, exclusive=True) + logits_masked = tf.where(probs_sums < p, logits_sort, tf.ones_like(logits_sort)*1000) # [batchsize, vocab] + min_logits = tf.reduce_min(logits_masked, axis=1, keepdims=True) # [batchsize, 1] + return tf.where( + logits < min_logits, + tf.ones_like(logits, dtype=logits.dtype) * -1e10, + logits, + ) + + +def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, top_p=0.0): if start_token is None: assert context is not None, 'Specify exactly one of start_token and context!' 
else:
@@ -49,7 +63,10 @@ def step(hparams, tokens, past=None):
         def body(past, prev, output):
             next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
             logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
-            logits = top_k_logits(logits, k=top_k)
+            if top_p > 0.0:
+                logits = top_p_logits(logits, p=top_p)
+            else:
+                logits = top_k_logits(logits, k=top_k)
             samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
             return [
                 tf.concat([past, next_outputs['presents']], axis=-2),
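
As a quick illustration of the nucleus-sampling rule that `top_p_logits` implements, here is a NumPy sketch (separate from the patch; the values are made up): tokens are sorted by descending probability, the smallest set whose exclusive cumulative probability stays below `p` keeps its logits, and every other token is masked to a large negative value.

```
import numpy as np

# Toy logits over a 6-token vocabulary (illustrative values only).
logits = np.array([2.0, 1.0, 0.5, 0.1, -1.0, -2.0])
p = 0.9

order = np.argsort(-logits)                 # sort tokens by descending probability
probs = np.exp(logits[order]) / np.exp(logits[order]).sum()
cum = np.cumsum(probs) - probs              # exclusive cumsum, as in top_p_logits
keep = order[cum < p]                       # smallest nucleus covering ~p of the mass

filtered = np.full_like(logits, -1e10)      # same sentinel value used by top_p_logits
filtered[keep] = logits[keep]
print(filtered)                             # tokens outside the nucleus are masked out
```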

diff --git a/train-horovod.py b/train-horovod.py
new file mode 100644
index 000000000..bea32773f
--- /dev/null
+++ b/train-horovod.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python3
+# Usage:
+# PYTHONPATH=src ./train-horovod.py --dataset <file|directory|glob>
+
+import fire
+import json
+import os
+import numpy as np
+import tensorflow as tf
+import random
+import time
+
+import horovod.tensorflow as hvd
+
+import model, sample, encoder
+from load_dataset import load_dataset, Sampler
+
+CHECKPOINT_DIR = 'checkpoint'
+SAMPLE_DIR = 'samples'
+
+hvd.init()
+
+def maketree(path):
+    try:
+        os.makedirs(path)
+    except:
+        pass
+
+
+def train_main(dataset,
+               model_name='117M',
+               seed=None,
+               batch_size=2,
+               sample_length=1023,
+               sample_num=1,
+               sample_every=4500,
+               run_name='run1',
+               restore_from='latest',
+               save_every=2000,
+               combine=50000):
+
+    enc = encoder.get_encoder(model_name)
+    hparams = model.default_hparams()
+    with open(os.path.join('models', model_name, 'hparams.json')) as f:
+        hparams.override_from_dict(json.load(f))
+
+    if sample_length is None:
+        sample_length = hparams.n_ctx // 2
+    elif sample_length > hparams.n_ctx:
+        raise ValueError(
+            "Can't get samples longer than window size: %s" % hparams.n_ctx)
+
+    # TF config
+
+    config = tf.ConfigProto()
+    config.gpu_options.visible_device_list = str(hvd.local_rank())
+    config.gpu_options.allow_growth = True
+
+    with tf.Session(config=config) as sess:
+        context = tf.placeholder(tf.int32, [batch_size, None])
+        np.random.seed(seed)
+        tf.set_random_seed(seed)
+        output = model.model(hparams=hparams, X=context)
+        loss = tf.reduce_mean(
+            tf.nn.sparse_softmax_cross_entropy_with_logits(
+                labels=context[:, 1:], logits=output['logits'][:, :-1]))
+
+        tf_sample = sample.sample_sequence(
+            hparams=hparams,
+            length=sample_length,
+            context=context,
+            batch_size=batch_size,
+            temperature=0.8,
+            top_k=40)
+
+        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
+
+        opt = tf.train.AdamOptimizer()
+        opt = hvd.DistributedOptimizer(opt)
+        train_op = opt.minimize(loss, var_list=train_vars)
+
+        # Horovod: broadcast initial variable states from rank 0 to all other processes.
+        # This is necessary to ensure consistent initialization of all workers when
+        # training is started with random weights or restored from a checkpoint.
+        bcast = hvd.broadcast_global_variables(0)
+
+        saver = tf.train.Saver(
+            var_list=train_vars,
+            max_to_keep=5,
+            keep_checkpoint_every_n_hours=2)
+
+        sess.run(tf.global_variables_initializer())
+
+
+        if restore_from == 'latest':
+            ckpt = tf.train.latest_checkpoint(
+                os.path.join(CHECKPOINT_DIR, run_name))
+            if ckpt is None:
+                # Get fresh GPT weights if new run.
+                ckpt = tf.train.latest_checkpoint(
+                    os.path.join('models', model_name))
+        elif restore_from == 'fresh':
+            ckpt = tf.train.latest_checkpoint(
+                os.path.join('models', model_name))
+        else:
+            ckpt = tf.train.latest_checkpoint(restore_from)
+        print(str(hvd.local_rank()), 'Loading checkpoint', ckpt)
+        saver.restore(sess, ckpt)
+
+        bcast.run()
+
+        print(str(hvd.local_rank()), 'Loading dataset...')
+        chunks = load_dataset(enc, dataset, combine)
+        data_sampler = Sampler(chunks)
+        print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size, 'tokens')
+        print(str(hvd.local_rank()), 'Training...')
+
+        counter = 1
+        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
+            # Load the step number if we're resuming a run
+            # Add 1 so we don't immediately try to save again
+            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
+                      'r') as fp:
+                counter = int(fp.read()) + 1
+
+        def save():
+            maketree(os.path.join(CHECKPOINT_DIR, run_name))
+            print(
+                'Saving',
+                os.path.join(CHECKPOINT_DIR, run_name,
+                             'model-{}').format(counter))
+            saver.save(
+                sess,
+                os.path.join(CHECKPOINT_DIR, run_name, 'model'),
+                global_step=counter)
+            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
+                      'w') as fp:
+                fp.write(str(counter) + '\n')
+
+        def generate_samples():
+            context_tokens = data_sampler.sample(1)
+            all_text = []
+            index = 0
+            while index < sample_num:
+                out = sess.run(
+                    tf_sample, feed_dict={context: batch_size*[context_tokens]})
+                for i in range(min(sample_num - index, batch_size)):
+                    text = enc.decode(out[i])
+                    text = '======== SAMPLE {} ========\n{}\n'.format(index + 1, text)
+                    all_text.append(text)
+                    index += 1
+                print(text)
+            maketree(os.path.join(SAMPLE_DIR, run_name))
+            with open(
+                    os.path.join(SAMPLE_DIR, run_name,
+                                 'samples-{}').format(counter), 'w') as fp:
+                fp.write('\n'.join(all_text))
+
+        avg_loss = (0.0, 0.0)
+        start_time = time.time()
+
+        try:
+            while True:
+
+                batch = [data_sampler.sample(1024) for _ in range(batch_size)]
+
+                _, lv = sess.run((train_op, loss), feed_dict={context: batch})
+
+                avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)
+
+                if hvd.rank() == 0:
+                    if counter % save_every == 0:
+                        save()
+                    if counter % sample_every == 0:
+                        generate_samples()
+
+                    print(
+                        '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
+                        .format(
+                            counter=counter,
+                            time=time.time() - start_time,
+                            loss=lv,
+                            avg=avg_loss[0] / avg_loss[1]))
+
+                counter += 1
+
+        except KeyboardInterrupt:
+            print('interrupted')
+            if hvd.rank() == 0:
+                save()
+
+
+if __name__ == '__main__':
+    fire.Fire(train_main)
diff --git a/train.py b/train.py
new file mode 100755
index 000000000..57e4ef92a
--- /dev/null
+++ b/train.py
@@ -0,0 +1,297 @@
+#!/usr/bin/env python3
+# Usage:
+# PYTHONPATH=src ./train.py --dataset <file|directory|glob>
+
+import argparse
+import json
+import os
+import numpy as np
+import tensorflow as tf
+import time
+import tqdm
+from tensorflow.core.protobuf import rewriter_config_pb2
+
+import model, sample, encoder
+from load_dataset import load_dataset, Sampler
+from accumulate import AccumulatingOptimizer
+import memory_saving_gradients
+
+CHECKPOINT_DIR = 'checkpoint'
+SAMPLE_DIR = 'samples'
+
+
+parser = argparse.ArgumentParser(
+    description='Fine-tune GPT-2 on your custom dataset.',
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+parser.add_argument('--dataset', metavar='PATH', type=str, required=True, help='Input file, directory, or glob pattern (utf-8 text, or preencoded .npz files).')
+parser.add_argument('--model_name', metavar='MODEL', type=str, 
default='117M', help='Pretrained model name')
+parser.add_argument('--combine', metavar='CHARS', type=int, default=50000, help='Concatenate input files with <|endoftext|> separator into chunks of this minimum size')
+parser.add_argument('--encoding', type=str, default='utf-8', help='Set the encoding for reading and writing files.')
+
+parser.add_argument('--batch_size', metavar='SIZE', type=int, default=1, help='Batch size')
+parser.add_argument('--learning_rate', metavar='LR', type=float, default=0.00002, help='Learning rate for Adam')
+parser.add_argument('--accumulate_gradients', metavar='N', type=int, default=1, help='Accumulate gradients across N minibatches.')
+parser.add_argument('--memory_saving_gradients', default=False, action='store_true', help='Use gradient checkpointing to reduce vram usage.')
+parser.add_argument('--only_train_transformer_layers', default=False, action='store_true', help='Restrict training to the transformer blocks.')
+parser.add_argument('--optimizer', type=str, default='adam', help='Optimizer. <adam|sgd>.')
+parser.add_argument('--noise', type=float, default=0.0, help='Add noise to input training data to regularize against typos.')
+
+parser.add_argument('--top_k', type=int, default=40, help='K for top-k sampling.')
+parser.add_argument('--top_p', type=float, default=0.0, help='P for top-p sampling. Overrides top_k if set > 0.')
+
+parser.add_argument('--restore_from', type=str, default='latest', help='Either "latest", "fresh", or a path to a checkpoint file')
+parser.add_argument('--run_name', type=str, default='run1', help='Run id. Name of subdirectory in checkpoint/ and samples/')
+parser.add_argument('--sample_every', metavar='N', type=int, default=100, help='Generate samples every N steps')
+parser.add_argument('--sample_length', metavar='TOKENS', type=int, default=1023, help='Sample this many tokens')
+parser.add_argument('--sample_num', metavar='N', type=int, default=1, help='Generate this many samples')
+parser.add_argument('--save_every', metavar='N', type=int, default=1000, help='Write a checkpoint every N steps')
+
+parser.add_argument('--val_dataset', metavar='PATH', type=str, default=None, help='Dataset for validation loss, defaults to --dataset.')
+parser.add_argument('--val_batch_size', metavar='SIZE', type=int, default=2, help='Batch size for validation.')
+parser.add_argument('--val_batch_count', metavar='N', type=int, default=40, help='Number of batches for validation.')
+parser.add_argument('--val_every', metavar='STEPS', type=int, default=0, help='Calculate validation loss every STEPS steps.')
+
+
+def maketree(path):
+    try:
+        os.makedirs(path)
+    except:
+        pass
+
+
+def randomize(context, hparams, p):
+    if p > 0:
+        mask = tf.random.uniform(shape=tf.shape(context)) < p
+        noise = tf.random.uniform(shape=tf.shape(context), minval=0, maxval=hparams.n_vocab, dtype=tf.int32)
+        return tf.where(mask, noise, context)
+    else:
+        return context
+
+
+def main():
+    args = parser.parse_args()
+    enc = encoder.get_encoder(args.model_name)
+    hparams = model.default_hparams()
+    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
+        hparams.override_from_dict(json.load(f))
+
+    if args.sample_length > hparams.n_ctx:
+        raise ValueError(
+            "Can't get samples longer than window size: %s" % hparams.n_ctx)
+
+    if args.model_name == '345M':
+        args.memory_saving_gradients = True
+        if args.optimizer == 'adam':
+            args.only_train_transformer_layers = True
+
+    config = tf.ConfigProto()
+    config.gpu_options.allow_growth = True
+    
config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF + with tf.Session(config=config) as sess: + context = tf.placeholder(tf.int32, [args.batch_size, None]) + context_in = randomize(context, hparams, args.noise) + output = model.model(hparams=hparams, X=context_in) + loss = tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=context[:, 1:], logits=output['logits'][:, :-1])) + + if args.val_every > 0: + val_context = tf.placeholder(tf.int32, [args.val_batch_size, None]) + val_output = model.model(hparams=hparams, X=val_context) + val_loss = tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=val_context[:, 1:], logits=val_output['logits'][:, :-1])) + val_loss_summary = tf.summary.scalar('val_loss', val_loss) + + + tf_sample = sample.sample_sequence( + hparams=hparams, + length=args.sample_length, + context=context, + batch_size=args.batch_size, + temperature=1.0, + top_k=args.top_k, + top_p=args.top_p) + + all_vars = [v for v in tf.trainable_variables() if 'model' in v.name] + train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars + + if args.optimizer == 'adam': + opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate) + elif args.optimizer == 'sgd': + opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate) + else: + exit('Bad optimizer:', args.optimizer) + + if args.accumulate_gradients > 1: + if args.memory_saving_gradients: + exit("Memory saving gradients are not implemented for gradient accumulation yet.") + opt = AccumulatingOptimizer( + opt=opt, + var_list=train_vars) + opt_reset = opt.reset() + opt_compute = opt.compute_gradients(loss) + opt_apply = opt.apply_gradients() + summary_loss = tf.summary.scalar('loss', opt_apply) + else: + if args.memory_saving_gradients: + opt_grads = memory_saving_gradients.gradients(loss, train_vars) + else: + opt_grads = tf.gradients(loss, train_vars) + opt_grads = list(zip(opt_grads, train_vars)) + opt_apply = opt.apply_gradients(opt_grads) + summary_loss = tf.summary.scalar('loss', loss) + + summary_lr = tf.summary.scalar('learning_rate', args.learning_rate) + summaries = tf.summary.merge([summary_lr, summary_loss]) + + summary_log = tf.summary.FileWriter( + os.path.join(CHECKPOINT_DIR, args.run_name)) + + saver = tf.train.Saver( + var_list=all_vars, + max_to_keep=5, + keep_checkpoint_every_n_hours=2) + sess.run(tf.global_variables_initializer()) + + if args.restore_from == 'latest': + ckpt = tf.train.latest_checkpoint( + os.path.join(CHECKPOINT_DIR, args.run_name)) + if ckpt is None: + # Get fresh GPT weights if new run. + ckpt = tf.train.latest_checkpoint( + os.path.join('models', args.model_name)) + elif args.restore_from == 'fresh': + ckpt = tf.train.latest_checkpoint( + os.path.join('models', args.model_name)) + else: + ckpt = tf.train.latest_checkpoint(args.restore_from) + print('Loading checkpoint', ckpt) + saver.restore(sess, ckpt) + + print('Loading dataset...') + chunks = load_dataset(enc, args.dataset, args.combine, encoding=args.encoding) + data_sampler = Sampler(chunks) + if args.val_every > 0: + if args.val_dataset: + val_chunks = load_dataset(enc, args.val_dataset, args.combine, encoding=args.encoding) + else: + val_chunks = chunks + print('dataset has', data_sampler.total_size, 'tokens') + print('Training...') + + if args.val_every > 0: + # Sample from validation set once with fixed seed to make + # it deterministic during training as well as across runs. 
+ val_data_sampler = Sampler(val_chunks, seed=1) + val_batches = [[val_data_sampler.sample(1024) for _ in range(args.val_batch_size)] + for _ in range(args.val_batch_count)] + + counter = 1 + counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter') + if os.path.exists(counter_path): + # Load the step number if we're resuming a run + # Add 1 so we don't immediately try to save again + with open(counter_path, 'r') as fp: + counter = int(fp.read()) + 1 + + def save(): + maketree(os.path.join(CHECKPOINT_DIR, args.run_name)) + print( + 'Saving', + os.path.join(CHECKPOINT_DIR, args.run_name, + 'model-{}').format(counter)) + saver.save( + sess, + os.path.join(CHECKPOINT_DIR, args.run_name, 'model'), + global_step=counter) + with open(counter_path, 'w') as fp: + fp.write(str(counter) + '\n') + + def generate_samples(): + print('Generating samples...') + context_tokens = data_sampler.sample(1) + all_text = [] + index = 0 + while index < args.sample_num: + out = sess.run( + tf_sample, + feed_dict={context: args.batch_size * [context_tokens]}) + for i in range(min(args.sample_num - index, args.batch_size)): + text = enc.decode(out[i]) + text = '======== SAMPLE {} ========\n{}\n'.format( + index + 1, text) + all_text.append(text) + index += 1 + print(text) + maketree(os.path.join(SAMPLE_DIR, args.run_name)) + with open( + os.path.join(SAMPLE_DIR, args.run_name, + 'samples-{}').format(counter), 'w', encoding=args.encoding) as fp: + fp.write('\n'.join(all_text)) + + def validation(): + print('Calculating validation loss...') + losses = [] + for batch in tqdm.tqdm(val_batches): + losses.append(sess.run(val_loss, feed_dict={val_context: batch})) + v_val_loss = np.mean(losses) + v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss}) + summary_log.add_summary(v_summary, counter) + summary_log.flush() + print( + '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}' + .format( + counter=counter, + time=time.time() - start_time, + loss=v_val_loss)) + + def sample_batch(): + return [data_sampler.sample(1024) for _ in range(args.batch_size)] + + + avg_loss = (0.0, 0.0) + start_time = time.time() + + try: + while True: + if counter % args.save_every == 0: + save() + if counter % args.sample_every == 0: + generate_samples() + if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1): + validation() + + if args.accumulate_gradients > 1: + sess.run(opt_reset) + for _ in range(args.accumulate_gradients): + sess.run( + opt_compute, feed_dict={context: sample_batch()}) + (v_loss, v_summary) = sess.run((opt_apply, summaries)) + else: + (_, v_loss, v_summary) = sess.run( + (opt_apply, loss, summaries), + feed_dict={context: sample_batch()}) + + summary_log.add_summary(v_summary, counter) + + avg_loss = (avg_loss[0] * 0.99 + v_loss, + avg_loss[1] * 0.99 + 1.0) + + print( + '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}' + .format( + counter=counter, + time=time.time() - start_time, + loss=v_loss, + avg=avg_loss[0] / avg_loss[1])) + + counter += 1 + except KeyboardInterrupt: + print('interrupted') + save() + + +if __name__ == '__main__': + main()
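
When `--accumulate_gradients N` is greater than 1, `train.py` drives `AccumulatingOptimizer` through a reset → compute ×N → apply cycle. A minimal self-contained sketch of that cycle (the toy model, names, and values here are hypothetical, not part of the patch; TF1 semantics as above):

```
import numpy as np
import tensorflow as tf
from accumulate import AccumulatingOptimizer  # src/accumulate.py above; run with PYTHONPATH=src

# Toy one-parameter model standing in for GPT-2.
x = tf.placeholder(tf.float32, [None])
w = tf.Variable(0.0)
loss = tf.reduce_mean((x * w - 1.0) ** 2)

opt = AccumulatingOptimizer(
    opt=tf.train.AdamOptimizer(learning_rate=1e-2),
    var_list=[w])
opt_reset = opt.reset()                    # zero the gradient buffers and loss counters
opt_compute = opt.compute_gradients(loss)  # add this minibatch's gradients to the buffers
opt_apply = opt.apply_gradients()          # apply the summed gradients, return the mean loss

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(opt_reset)
    for _ in range(4):                     # as with --accumulate_gradients 4
        sess.run(opt_compute, feed_dict={x: np.random.randn(8)})
    print('mean loss over 4 minibatches:', sess.run(opt_apply))
```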