tensorflow版本的tansformer训练IWSLT数据集

时间:2022-07-23
本文章向大家介绍tensorflow版本的tansformer训练IWSLT数据集,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

代码来源:https://github.com/Kyubyong/transformer

1、git clone https://github.com/Kyubyong/transformer.git

2、pip install sentencepiece

3、下载数据集

进入到tansformer目录下,输入:sh download.sh

运行成功之后,会有这么一些文件:

de-en.de.xml中内容大致是这个样子的:

<?xml version="1.0" encoding="UTF-8"?>
<mteval>
<srcset setid="iwslt2016-dev2010" srclang="german">
<doc docid="69" genre="lectures">
<url>http://www.ted.com/talks/lang/de/wade_davis_on_endangered_cultures.html</url>
<description>Mit atemberaubenden Fotos und Geschichten feiert der National Geographic- Forschungsreisende Wade Davis, die außergewöhnliche Vielfalt der Ureinwohner der Welt, welche in alarmierender Anzahl von unserem Planeten verschwinden.</description>
<keywords>anthropology,culture,environment,film,global issues,language,photography</keywords>
<talkid>69</talkid>
<title>Wade Davis über gefährdete Kulturen</title>
<seg id="1"> Wissen Sie, eines der großen Vernügen beim Reisen und eine der Freuden bei der ethnographischen Forschung ist, gemeinsam mit den Menschen zu leben, die sich noch an die alten Tage erinnern können. Die ihre Vergangenheit noch immer im Wind spüren, sie auf vom Regen geglätteten Steinen berühren, sie in den bitteren Blättern der Pflanzen schmecken. </seg>
<seg id="2"> Einfach das Wissen, dass Jaguar-Schamanen noch immer jenseits der Milchstraße reisen oder die Bedeutung der Mythen der Ältesten der Inuit noch voller Bedeutung sind, oder dass im Himalaya die Buddhisten noch immer den Atem des Dharma verfolgen, bedeutet, sich die zentrale Offenbarung der Anthropologie ins Gedächtnis zu rufen, das ist der Gedanke, dass die Welt, in der wir leben, nicht in einem absoluten Sinn existiert, sondern nur als ein Modell der Realität, als eine Folge einer Gruppe von bestimmten Möglichkeiten der Anpassung die unsere Ahnen, wenngleich erfolgreich, vor vielen Generationen wählten. </seg>
<seg id="3"> Und natürlich teilen wir alle dieselben Anpassungsnotwendigkeiten. </seg>
<seg id="4"> Wir werden alle geboren. Wir bringen Kinder zur Welt. </seg>
<seg id="5"> Wir durchlaufen Initiationsrituale. </seg>
<seg id="6"> Wir müssen uns mit der unaufhaltsamen Trennung durch den Tod auseinandersetzen und somit sollte es uns nicht überraschen, dass wir alle singen, tanzen und und Kunst hervorbringen. </seg>
<seg id="7"> Aber interessant ist der einzigartige Tonfall des Liedes, der Rhythmus des Tanzes in jeder Kultur. </seg>
<seg id="8"> Dabei spielt es keine Rolle, ob es sich um die Penan in den Wäldern von Borneo handelt, oder die Voodoo-Akolythen in Haiti, oder die Krieger in der Kaisut-Wüste von Nordkenia, die Curanderos in den Anden, oder eine Karawanserei mitten in der Sahara. Dies ist zufällig der Kollege, mit dem ich vor einem Monat in die Wüste gereist bin. Oder selbst ein Yak-Hirte an den Hängen des Qomolangma, Everest, der Gottmutter der Welt. </seg>
<seg id="9"> All diese Menschen lehren uns, dass es noch andere Existenzmöglichkeiten, andere Denkweisen, andere Wege zur Orientierung auf der Erde gibt. </seg>
<seg id="10"> Und das ist eine Vorstellung, die, wenn man darüber nachdenkt, einen nur mit Hoffnung erfüllen kann. </seg>
<seg id="11"> Zusammen bilden die unzähligen Kulturen der Welt ein Netz aus spirituellem und kulturellem Leben, das die Erde umhüllt und für das Wohl der Erde genauso wichtig ist, wie das biologische Lebensnetz, das man als Biosphäre kennt. </seg>
<seg id="12"> Man kann sich dieses kulturelle Lebensnetz als eine Ethnosphäre vorstellen. Ethnosphäre kann dabei als die Gesamtsumme aller Gedanken und Träume, Mythen Ideen, Inspirationen und Intuitionen, die von der menschlichen Vorstellungskraft seit den Anfängen des Bewusstseins hervorgebracht wurden, definiert werden. </seg>
<seg id="13"> Die Ethnosphäre ist das großartige Vermächtnis der Menschheit. </seg>
<seg id="14"> Sie ist das Symbol all dessen, was wir sind und wozu wir als erstaunlich wissbegierige Spezies fähig sind. </seg>
<seg id="15"> Und genauso wie die Biosphäre stark abgetragen wurde, geschah dies mit der Ethnosphäre -- nur mit noch größerer Geschwindigkeit. </seg>
<seg id="16"> Kein Biologe würde zum Beispiel wagen zu behaupten, dass 50% oder mehr aller Arten kurz vor dem Aussterben sind, da es einfach nicht stimmt. Und doch, dieses -- das apokalyptischste Szenarium auf dem Gebiet der biologischen Vielfalt -- entspricht kaum dem, was uns als optimistischstes Szenarium auf dem Gebiet der kulturellen Vielfalt bekannt ist. </seg>
<seg id="17"> Und der entscheidende Indikator dafür ist das Aussterben der Sprachen. </seg>

4、创建训练集、验证集、测试集

python prepro.py --vocab_size 8000

部分运行结果:

trainer_interface.cc(615) LOG(INFO) Saving model: iwslt2016/segmented/bpe.model
trainer_interface.cc(626) LOG(INFO) Saving vocabs: iwslt2016/segmented/bpe.vocab
INFO:root:# Load trained bpe model
INFO:root:# Segment
INFO:root:Let's see how segmented data look like
train1: ▁David ▁G all o : ▁Das ▁ist ▁Bill ▁L ange . ▁Ich ▁bin ▁Da ve ▁G all o .

train2: ▁David ▁G all o : ▁This ▁is ▁Bill ▁L ange . ▁I ' m ▁Da ve ▁G all o .

eval1: ▁Als ▁ich ▁11 ▁Jahre ▁alt ▁war , ▁wurde ▁ich ▁eines ▁Morgen s ▁von ▁den ▁Kl ängen ▁h eller ▁Freude ▁ge we ckt .

eval2: ▁When ▁I ▁was ▁11 , ▁I ▁remember ▁w aking ▁up ▁one ▁morning ▁to ▁the ▁sound ▁of ▁j oy ▁in ▁my ▁house .

test1: ▁Als ▁ich ▁in ▁meinen ▁20 ern ▁war , ▁hatte ▁ich ▁meine ▁erste ▁Psych other ap ie - P at ient in .

INFO:root:Done

运行之后会有:

prepro.py中的内容如下:

# -*- coding: utf-8 -*-
#/usr/bin/python3
'''
Feb. 2019 by kyubyong park.
kbpark.linguist@gmail.com.
https://www.github.com/kyubyong/transformer.

Preprocess the iwslt 2016 datasets.
'''

import os
import errno
import sentencepiece as spm
import re
from hparams import Hparams
import logging

logging.basicConfig(level=logging.INFO)

def prepro(hp):
    """Load raw data -> Preprocessing -> Segmenting with sentencepice
    hp: hyperparams. argparse.
    """
    logging.info("# Check if raw files exist")
    train1 = "iwslt2016/de-en/train.tags.de-en.de"
    train2 = "iwslt2016/de-en/train.tags.de-en.en"
    eval1 = "iwslt2016/de-en/IWSLT16.TED.tst2013.de-en.de.xml"
    eval2 = "iwslt2016/de-en/IWSLT16.TED.tst2013.de-en.en.xml"
    test1 = "iwslt2016/de-en/IWSLT16.TED.tst2014.de-en.de.xml"
    test2 = "iwslt2016/de-en/IWSLT16.TED.tst2014.de-en.en.xml"
    for f in (train1, train2, eval1, eval2, test1, test2):
        if not os.path.isfile(f):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), f)

    logging.info("# Preprocessing")
    # train
    _prepro = lambda x:  [line.strip() for line in open(x, 'r').read().split("n") 
                      if not line.startswith("<")]
    prepro_train1, prepro_train2 = _prepro(train1), _prepro(train2)
    assert len(prepro_train1)==len(prepro_train2), "Check if train source and target files match."

    # eval
    _prepro = lambda x: [re.sub("<[^>]+>", "", line).strip() 
                     for line in open(x, 'r').read().split("n") 
                     if line.startswith("<seg id")]
    prepro_eval1, prepro_eval2 = _prepro(eval1), _prepro(eval2)
    assert len(prepro_eval1) == len(prepro_eval2), "Check if eval source and target files match."

    # test
    prepro_test1, prepro_test2 = _prepro(test1), _prepro(test2)
    assert len(prepro_test1) == len(prepro_test2), "Check if test source and target files match."

    logging.info("Let's see how preprocessed data look like")
    logging.info("prepro_train1:", prepro_train1[0])
    logging.info("prepro_train2:", prepro_train2[0])
    logging.info("prepro_eval1:", prepro_eval1[0])
    logging.info("prepro_eval2:", prepro_eval2[0])
    logging.info("prepro_test1:", prepro_test1[0])
    logging.info("prepro_test2:", prepro_test2[0])

    logging.info("# write preprocessed files to disk")
    os.makedirs("iwslt2016/prepro", exist_ok=True)
    def _write(sents, fname):
        with open(fname, 'w') as fout:
            fout.write("n".join(sents))

    _write(prepro_train1, "iwslt2016/prepro/train.de")
    _write(prepro_train2, "iwslt2016/prepro/train.en")
    _write(prepro_train1+prepro_train2, "iwslt2016/prepro/train")
    _write(prepro_eval1, "iwslt2016/prepro/eval.de")
    _write(prepro_eval2, "iwslt2016/prepro/eval.en")
    _write(prepro_test1, "iwslt2016/prepro/test.de")
    _write(prepro_test2, "iwslt2016/prepro/test.en")

    logging.info("# Train a joint BPE model with sentencepiece")
    os.makedirs("iwslt2016/segmented", exist_ok=True)
    train = '--input=iwslt2016/prepro/train --pad_id=0 --unk_id=1 
             --bos_id=2 --eos_id=3
             --model_prefix=iwslt2016/segmented/bpe --vocab_size={} 
             --model_type=bpe'.format(hp.vocab_size)
    spm.SentencePieceTrainer.Train(train)

    logging.info("# Load trained bpe model")
    sp = spm.SentencePieceProcessor()
    sp.Load("iwslt2016/segmented/bpe.model")

    logging.info("# Segment")
    def _segment_and_write(sents, fname):
        with open(fname, "w") as fout:
            for sent in sents:
                pieces = sp.EncodeAsPieces(sent)
                fout.write(" ".join(pieces) + "n")

    _segment_and_write(prepro_train1, "iwslt2016/segmented/train.de.bpe")
    _segment_and_write(prepro_train2, "iwslt2016/segmented/train.en.bpe")
    _segment_and_write(prepro_eval1, "iwslt2016/segmented/eval.de.bpe")
    _segment_and_write(prepro_eval2, "iwslt2016/segmented/eval.en.bpe")
    _segment_and_write(prepro_test1, "iwslt2016/segmented/test.de.bpe")

    logging.info("Let's see how segmented data look like")
    print("train1:", open("iwslt2016/segmented/train.de.bpe",'r').readline())
    print("train2:", open("iwslt2016/segmented/train.en.bpe", 'r').readline())
    print("eval1:", open("iwslt2016/segmented/eval.de.bpe", 'r').readline())
    print("eval2:", open("iwslt2016/segmented/eval.en.bpe", 'r').readline())
    print("test1:", open("iwslt2016/segmented/test.de.bpe", 'r').readline())

if __name__ == '__main__':
    hparams = Hparams()
    parser = hparams.parser
    hp = parser.parse_args()
    prepro(hp)
    logging.info("Done")

train中部分内容如下:

David Gallo: Das ist Bill Lange. Ich bin Dave Gallo.
Wir werden Ihnen einige Geschichten über das Meer in Videoform erzählen.
Wir haben ein paar der unglaublichsten Aufnahmen der Titanic, die man je gesehen hat,, und wir werden Ihnen nichts davon zeigen.
Die Wahrheit ist, dass die Titanic – obwohl sie alle Kinokassenrekorde bricht – nicht gerade die aufregendste Geschichte vom Meer ist.
Ich denke, das Problem ist, dass wir das Meer für zu selbstverständlich halten.
Wenn man darüber nachdenkt, machen die Ozeane 75 % des Planeten aus.
Der Großteil der Erde ist Meerwasser.

train.en.bpe中部分内容如下:

▁David ▁G all o : ▁This ▁is ▁Bill ▁L ange . ▁I ' m ▁Da ve ▁G all o .
▁And ▁we ' re ▁going ▁to ▁tell ▁you ▁some ▁stories ▁from ▁the ▁sea ▁here ▁in ▁video .
▁We ' ve ▁got ▁some ▁of ▁the ▁most ▁incredible ▁video ▁of ▁Tit an ic ▁that ' s ▁ever ▁been ▁seen , ▁and ▁we ' re ▁not ▁going ▁to ▁show ▁you ▁any ▁of ▁it .
▁The ▁truth ▁of ▁the ▁matter ▁is ▁that ▁the ▁Tit an ic ▁-- ▁even ▁though ▁it ' s ▁break ing ▁all ▁sorts ▁of ▁box ▁office ▁record s ▁-- ▁it ' s ▁not ▁the ▁most ▁exciting ▁story ▁from ▁the ▁sea .
▁And ▁the ▁problem , ▁I ▁think , ▁is ▁that ▁we ▁take ▁the ▁ocean ▁for ▁gr anted .
▁When ▁you ▁think ▁about ▁it , ▁the ▁oce ans ▁are ▁75 ▁percent ▁of ▁the ▁planet .
▁Most ▁of ▁the ▁planet ▁is ▁ocean ▁water .

bpe.vocab部分内容如下:

<pad>    0
<unk>    0
<s>    0
</s>    0
en    -0
er    -1
in    -2
▁t    -3
ch    -4
▁a    -5
▁d    -6
▁w    -7
▁s    -8
▁th    -9
nd    -10
ie    -11
es    -12

5、train.py

# -*- coding: utf-8 -*-
#/usr/bin/python3
'''
Feb. 2019 by kyubyong park.
kbpark.linguist@gmail.com.
https://www.github.com/kyubyong/transformer
'''
import tensorflow as tf

from model import Transformer
from tqdm import tqdm
from data_load import get_batch
from utils import save_hparams, save_variable_specs, get_hypotheses, calc_bleu
import os
from hparams import Hparams
import math
import logging

logging.basicConfig(level=logging.INFO)


logging.info("# hparams")
hparams = Hparams()
parser = hparams.parser
hp = parser.parse_args()
save_hparams(hp, hp.logdir)

logging.info("# Prepare train/eval batches")
train_batches, num_train_batches, num_train_samples = get_batch(hp.train1, hp.train2,
                                             hp.maxlen1, hp.maxlen2,
                                             hp.vocab, hp.batch_size,
                                             shuffle=True)
eval_batches, num_eval_batches, num_eval_samples = get_batch(hp.eval1, hp.eval2,
                                             100000, 100000,
                                             hp.vocab, hp.batch_size,
                                             shuffle=False)

# create a iterator of the correct shape and type
iter = tf.data.Iterator.from_structure(train_batches.output_types, train_batches.output_shapes)
xs, ys = iter.get_next()

train_init_op = iter.make_initializer(train_batches)
eval_init_op = iter.make_initializer(eval_batches)

logging.info("# Load model")
m = Transformer(hp)
loss, train_op, global_step, train_summaries = m.train(xs, ys)
y_hat, eval_summaries = m.eval(xs, ys)
# y_hat = m.infer(xs, ys)

logging.info("# Session")
saver = tf.train.Saver(max_to_keep=hp.num_epochs)
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    if ckpt is None:
        logging.info("Initializing from scratch")
        sess.run(tf.global_variables_initializer())
        save_variable_specs(os.path.join(hp.logdir, "specs"))
    else:
        saver.restore(sess, ckpt)

    summary_writer = tf.summary.FileWriter(hp.logdir, sess.graph)

    sess.run(train_init_op)
    total_steps = hp.num_epochs * num_train_batches
    _gs = sess.run(global_step)
    for i in tqdm(range(_gs, total_steps+1)):
        _, _gs, _summary = sess.run([train_op, global_step, train_summaries])
        epoch = math.ceil(_gs / num_train_batches)
        summary_writer.add_summary(_summary, _gs)

        if _gs and _gs % num_train_batches == 0:
            logging.info("epoch {} is done".format(epoch))
            _loss = sess.run(loss) # train loss

            logging.info("# test evaluation")
            _, _eval_summaries = sess.run([eval_init_op, eval_summaries])
            summary_writer.add_summary(_eval_summaries, _gs)

            logging.info("# get hypotheses")
            hypotheses = get_hypotheses(num_eval_batches, num_eval_samples, sess, y_hat, m.idx2token)

            logging.info("# write results")
            model_output = "iwslt2016_E%02dL%.2f" % (epoch, _loss)
            if not os.path.exists(hp.evaldir): os.makedirs(hp.evaldir)
            translation = os.path.join(hp.evaldir, model_output)
            with open(translation, 'w') as fout:
                fout.write("n".join(hypotheses))

            logging.info("# calc bleu score and append it to translation")
            calc_bleu(hp.eval3, translation)

            logging.info("# save models")
            ckpt_name = os.path.join(hp.logdir, model_output)
            saver.save(sess, ckpt_name, global_step=_gs)
            logging.info("after training of {} epochs, {} has been saved.".format(epoch, ckpt_name))

            logging.info("# fall back to train mode")
            sess.run(train_init_op)
    summary_writer.close()


logging.info("Done")

我们一行行来看:

首先调用了hparams.py中的函数:

import argparse

class Hparams:
    parser = argparse.ArgumentParser()

    # prepro
    parser.add_argument('--vocab_size', default=32000, type=int)

    # train
    ## files
    parser.add_argument('--train1', default='iwslt2016/segmented/train.de.bpe',
                             help="german training segmented data")
    parser.add_argument('--train2', default='iwslt2016/segmented/train.en.bpe',
                             help="english training segmented data")
    parser.add_argument('--eval1', default='iwslt2016/segmented/eval.de.bpe',
                             help="german evaluation segmented data")
    parser.add_argument('--eval2', default='iwslt2016/segmented/eval.en.bpe',
                             help="english evaluation segmented data")
    parser.add_argument('--eval3', default='iwslt2016/prepro/eval.en',
                             help="english evaluation unsegmented data")

    ## vocabulary
    parser.add_argument('--vocab', default='iwslt2016/segmented/bpe.vocab',
                        help="vocabulary file path")

    # training scheme
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--eval_batch_size', default=128, type=int)

    parser.add_argument('--lr', default=0.0003, type=float, help="learning rate")
    parser.add_argument('--warmup_steps', default=4000, type=int)
    parser.add_argument('--logdir', default="log/1", help="log directory")
    parser.add_argument('--num_epochs', default=20, type=int)
    parser.add_argument('--evaldir', default="eval/1", help="evaluation dir")

    # model
    parser.add_argument('--d_model', default=512, type=int,
                        help="hidden dimension of encoder/decoder")
    parser.add_argument('--d_ff', default=2048, type=int,
                        help="hidden dimension of feedforward layer")
    parser.add_argument('--num_blocks', default=6, type=int,
                        help="number of encoder/decoder blocks")
    parser.add_argument('--num_heads', default=8, type=int,
                        help="number of attention heads")
    parser.add_argument('--maxlen1', default=100, type=int,
                        help="maximum length of a source sequence")
    parser.add_argument('--maxlen2', default=100, type=int,
                        help="maximum length of a target sequence")
    parser.add_argument('--dropout_rate', default=0.3, type=float)
    parser.add_argument('--smoothing', default=0.1, type=float,
                        help="label smoothing rate")

    # test
    parser.add_argument('--test1', default='iwslt2016/segmented/test.de.bpe',
                        help="german test segmented data")
    parser.add_argument('--test2', default='iwslt2016/prepro/test.en',
                        help="english test data")
    parser.add_argument('--ckpt', help="checkpoint file path")
    parser.add_argument('--test_batch_size', default=128, type=int)
    parser.add_argument('--testdir', default="test/1", help="test result dir")

主要是一些超参数的设置。

然后是data_load.py中用来加载数据集:

# -*- coding: utf-8 -*-
#/usr/bin/python3
'''
Feb. 2019 by kyubyong park.
kbpark.linguist@gmail.com.
https://www.github.com/kyubyong/transformer

Note.
if safe, entities on the source side have the prefix 1, and the target side 2, for convenience.
For example, fpath1, fpath2 means source file path and target file path, respectively.
'''
import tensorflow as tf
from utils import calc_num_batches

def load_vocab(vocab_fpath):
    '''Loads vocabulary file and returns idx<->token maps
    vocab_fpath: string. vocabulary file path.
    Note that these are reserved
    0: <pad>, 1: <unk>, 2: <s>, 3: </s>

    Returns
    two dictionaries.
    '''
    vocab = [line.split()[0] for line in open(vocab_fpath, 'r').read().splitlines()]
    token2idx = {token: idx for idx, token in enumerate(vocab)}
    idx2token = {idx: token for idx, token in enumerate(vocab)}
    return token2idx, idx2token

def load_data(fpath1, fpath2, maxlen1, maxlen2):
    '''Loads source and target data and filters out too lengthy samples.
    fpath1: source file path. string.
    fpath2: target file path. string.
    maxlen1: source sent maximum length. scalar.
    maxlen2: target sent maximum length. scalar.

    Returns
    sents1: list of source sents
    sents2: list of target sents
    '''
    sents1, sents2 = [], []
    with open(fpath1, 'r') as f1, open(fpath2, 'r') as f2:
        for sent1, sent2 in zip(f1, f2):
            if len(sent1.split()) + 1 > maxlen1: continue # 1: </s>
            if len(sent2.split()) + 1 > maxlen2: continue  # 1: </s>
            sents1.append(sent1.strip())
            sents2.append(sent2.strip())
    return sents1, sents2


def encode(inp, type, dict):
    '''Converts string to number. Used for `generator_fn`.
    inp: 1d byte array.
    type: "x" (source side) or "y" (target side)
    dict: token2idx dictionary

    Returns
    list of numbers
    '''
    inp_str = inp.decode("utf-8")
    if type=="x": tokens = inp_str.split() + ["</s>"]
    else: tokens = ["<s>"] + inp_str.split() + ["</s>"]

    x = [dict.get(t, dict["<unk>"]) for t in tokens]
    return x

def generator_fn(sents1, sents2, vocab_fpath):
    '''Generates training / evaluation data
    sents1: list of source sents
    sents2: list of target sents
    vocab_fpath: string. vocabulary file path.

    yields
    xs: tuple of
        x: list of source token ids in a sent
        x_seqlen: int. sequence length of x
        sent1: str. raw source (=input) sentence
    labels: tuple of
        decoder_input: decoder_input: list of encoded decoder inputs
        y: list of target token ids in a sent
        y_seqlen: int. sequence length of y
        sent2: str. target sentence
    '''
    token2idx, _ = load_vocab(vocab_fpath)
    for sent1, sent2 in zip(sents1, sents2):
        x = encode(sent1, "x", token2idx)
        y = encode(sent2, "y", token2idx)
        decoder_input, y = y[:-1], y[1:]

        x_seqlen, y_seqlen = len(x), len(y)
        yield (x, x_seqlen, sent1), (decoder_input, y, y_seqlen, sent2)

def input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=False):
    '''Batchify data
    sents1: list of source sents
    sents2: list of target sents
    vocab_fpath: string. vocabulary file path.
    batch_size: scalar
    shuffle: boolean

    Returns
    xs: tuple of
        x: int32 tensor. (N, T1)
        x_seqlens: int32 tensor. (N,)
        sents1: str tensor. (N,)
    ys: tuple of
        decoder_input: int32 tensor. (N, T2)
        y: int32 tensor. (N, T2)
        y_seqlen: int32 tensor. (N, )
        sents2: str tensor. (N,)
    '''
    shapes = (([None], (), ()),
              ([None], [None], (), ()))
    types = ((tf.int32, tf.int32, tf.string),
             (tf.int32, tf.int32, tf.int32, tf.string))
    paddings = ((0, 0, ''),
                (0, 0, 0, ''))

    dataset = tf.data.Dataset.from_generator(
        generator_fn,
        output_shapes=shapes,
        output_types=types,
        args=(sents1, sents2, vocab_fpath))  # <- arguments for generator_fn. converted to np string arrays

    if shuffle: # for training
        dataset = dataset.shuffle(128*batch_size)

    dataset = dataset.repeat()  # iterate forever
    dataset = dataset.padded_batch(batch_size, shapes, paddings).prefetch(1)

    return dataset

def get_batch(fpath1, fpath2, maxlen1, maxlen2, vocab_fpath, batch_size, shuffle=False):
    '''Gets training / evaluation mini-batches
    fpath1: source file path. string.
    fpath2: target file path. string.
    maxlen1: source sent maximum length. scalar.
    maxlen2: target sent maximum length. scalar.
    vocab_fpath: string. vocabulary file path.
    batch_size: scalar
    shuffle: boolean

    Returns
    batches
    num_batches: number of mini-batches
    num_samples
    '''
    sents1, sents2 = load_data(fpath1, fpath2, maxlen1, maxlen2)
    batches = input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=shuffle)
    num_batches = calc_num_batches(len(sents1), batch_size)
    return batches, num_batches, len(sents1)

6、看一下相关模型model.py

# -*- coding: utf-8 -*-
# /usr/bin/python3
'''
Feb. 2019 by kyubyong park.
kbpark.linguist@gmail.com.
https://www.github.com/kyubyong/transformer

Transformer network
'''
import tensorflow as tf

from data_load import load_vocab
from modules import get_token_embeddings, ff, positional_encoding, multihead_attention, label_smoothing, noam_scheme
from utils import convert_idx_to_token_tensor
from tqdm import tqdm
import logging

logging.basicConfig(level=logging.INFO)

class Transformer:
    '''
    xs: tuple of
        x: int32 tensor. (N, T1)
        x_seqlens: int32 tensor. (N,)
        sents1: str tensor. (N,)
    ys: tuple of
        decoder_input: int32 tensor. (N, T2)
        y: int32 tensor. (N, T2)
        y_seqlen: int32 tensor. (N, )
        sents2: str tensor. (N,)
    training: boolean.
    '''
    def __init__(self, hp):
        self.hp = hp
        self.token2idx, self.idx2token = load_vocab(hp.vocab)
        self.embeddings = get_token_embeddings(self.hp.vocab_size, self.hp.d_model, zero_pad=True)

    def encode(self, xs, training=True):
        '''
        Returns
        memory: encoder outputs. (N, T1, d_model)
        '''
        with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
            x, seqlens, sents1 = xs

            # src_masks
            src_masks = tf.math.equal(x, 0) # (N, T1)

            # embedding
            enc = tf.nn.embedding_lookup(self.embeddings, x) # (N, T1, d_model)
            enc *= self.hp.d_model**0.5 # scale

            enc += positional_encoding(enc, self.hp.maxlen1)
            enc = tf.layers.dropout(enc, self.hp.dropout_rate, training=training)

            ## Blocks
            for i in range(self.hp.num_blocks):
                with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
                    # self-attention
                    enc = multihead_attention(queries=enc,
                                              keys=enc,
                                              values=enc,
                                              key_masks=src_masks,
                                              num_heads=self.hp.num_heads,
                                              dropout_rate=self.hp.dropout_rate,
                                              training=training,
                                              causality=False)
                    # feed forward
                    enc = ff(enc, num_units=[self.hp.d_ff, self.hp.d_model])
        memory = enc
        return memory, sents1, src_masks

    def decode(self, ys, memory, src_masks, training=True):
        '''
        memory: encoder outputs. (N, T1, d_model)
        src_masks: (N, T1)

        Returns
        logits: (N, T2, V). float32.
        y_hat: (N, T2). int32
        y: (N, T2). int32
        sents2: (N,). string.
        '''
        with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
            decoder_inputs, y, seqlens, sents2 = ys

            # tgt_masks
            tgt_masks = tf.math.equal(decoder_inputs, 0)  # (N, T2)

            # embedding
            dec = tf.nn.embedding_lookup(self.embeddings, decoder_inputs)  # (N, T2, d_model)
            dec *= self.hp.d_model ** 0.5  # scale

            dec += positional_encoding(dec, self.hp.maxlen2)
            dec = tf.layers.dropout(dec, self.hp.dropout_rate, training=training)

            # Blocks
            for i in range(self.hp.num_blocks):
                with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
                    # Masked self-attention (Note that causality is True at this time)
                    dec = multihead_attention(queries=dec,
                                              keys=dec,
                                              values=dec,
                                              key_masks=tgt_masks,
                                              num_heads=self.hp.num_heads,
                                              dropout_rate=self.hp.dropout_rate,
                                              training=training,
                                              causality=True,
                                              scope="self_attention")

                    # Vanilla attention
                    dec = multihead_attention(queries=dec,
                                              keys=memory,
                                              values=memory,
                                              key_masks=src_masks,
                                              num_heads=self.hp.num_heads,
                                              dropout_rate=self.hp.dropout_rate,
                                              training=training,
                                              causality=False,
                                              scope="vanilla_attention")
                    ### Feed Forward
                    dec = ff(dec, num_units=[self.hp.d_ff, self.hp.d_model])

        # Final linear projection (embedding weights are shared)
        weights = tf.transpose(self.embeddings) # (d_model, vocab_size)
        logits = tf.einsum('ntd,dk->ntk', dec, weights) # (N, T2, vocab_size)
        y_hat = tf.to_int32(tf.argmax(logits, axis=-1))

        return logits, y_hat, y, sents2

    def train(self, xs, ys):
        '''
        Returns
        loss: scalar.
        train_op: training operation
        global_step: scalar.
        summaries: training summary node
        '''
        # forward
        memory, sents1, src_masks = self.encode(xs)
        logits, preds, y, sents2 = self.decode(ys, memory, src_masks)

        # train scheme
        y_ = label_smoothing(tf.one_hot(y, depth=self.hp.vocab_size))
        ce = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=y_)
        nonpadding = tf.to_float(tf.not_equal(y, self.token2idx["<pad>"]))  # 0: <pad>
        loss = tf.reduce_sum(ce * nonpadding) / (tf.reduce_sum(nonpadding) + 1e-7)

        global_step = tf.train.get_or_create_global_step()
        lr = noam_scheme(self.hp.lr, global_step, self.hp.warmup_steps)
        optimizer = tf.train.AdamOptimizer(lr)
        train_op = optimizer.minimize(loss, global_step=global_step)

        tf.summary.scalar('lr', lr)
        tf.summary.scalar("loss", loss)
        tf.summary.scalar("global_step", global_step)

        summaries = tf.summary.merge_all()

        return loss, train_op, global_step, summaries

    def eval(self, xs, ys):
        '''Predicts autoregressively
        At inference, input ys is ignored.
        Returns
        y_hat: (N, T2)
        '''
        decoder_inputs, y, y_seqlen, sents2 = ys

        decoder_inputs = tf.ones((tf.shape(xs[0])[0], 1), tf.int32) * self.token2idx["<s>"]
        ys = (decoder_inputs, y, y_seqlen, sents2)

        memory, sents1, src_masks = self.encode(xs, False)

        logging.info("Inference graph is being built. Please be patient.")
        for _ in tqdm(range(self.hp.maxlen2)):
            logits, y_hat, y, sents2 = self.decode(ys, memory, src_masks, False)
            if tf.reduce_sum(y_hat, 1) == self.token2idx["<pad>"]: break

            _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
            ys = (_decoder_inputs, y, y_seqlen, sents2)

        # monitor a random sample
        n = tf.random_uniform((), 0, tf.shape(y_hat)[0]-1, tf.int32)
        sent1 = sents1[n]
        pred = convert_idx_to_token_tensor(y_hat[n], self.idx2token)
        sent2 = sents2[n]

        tf.summary.text("sent1", sent1)
        tf.summary.text("pred", pred)
        tf.summary.text("sent2", sent2)
        summaries = tf.summary.merge_all()

        return y_hat, summaries