tensorflow版本的tansformer训练IWSLT数据集
时间:2022-07-23
本文章向大家介绍tensorflow版本的tansformer训练IWSLT数据集,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
代码来源:https://github.com/Kyubyong/transformer
1、git clone https://github.com/Kyubyong/transformer.git
2、pip install sentencepiece
3、下载数据集
进入到tansformer目录下,输入:sh download.sh
运行成功之后,会有这么一些文件:
de-en.de.xml中内容大致是这个样子的:
<?xml version="1.0" encoding="UTF-8"?>
<mteval>
<srcset setid="iwslt2016-dev2010" srclang="german">
<doc docid="69" genre="lectures">
<url>http://www.ted.com/talks/lang/de/wade_davis_on_endangered_cultures.html</url>
<description>Mit atemberaubenden Fotos und Geschichten feiert der National Geographic- Forschungsreisende Wade Davis, die außergewöhnliche Vielfalt der Ureinwohner der Welt, welche in alarmierender Anzahl von unserem Planeten verschwinden.</description>
<keywords>anthropology,culture,environment,film,global issues,language,photography</keywords>
<talkid>69</talkid>
<title>Wade Davis über gefährdete Kulturen</title>
<seg id="1"> Wissen Sie, eines der großen Vernügen beim Reisen und eine der Freuden bei der ethnographischen Forschung ist, gemeinsam mit den Menschen zu leben, die sich noch an die alten Tage erinnern können. Die ihre Vergangenheit noch immer im Wind spüren, sie auf vom Regen geglätteten Steinen berühren, sie in den bitteren Blättern der Pflanzen schmecken. </seg>
<seg id="2"> Einfach das Wissen, dass Jaguar-Schamanen noch immer jenseits der Milchstraße reisen oder die Bedeutung der Mythen der Ältesten der Inuit noch voller Bedeutung sind, oder dass im Himalaya die Buddhisten noch immer den Atem des Dharma verfolgen, bedeutet, sich die zentrale Offenbarung der Anthropologie ins Gedächtnis zu rufen, das ist der Gedanke, dass die Welt, in der wir leben, nicht in einem absoluten Sinn existiert, sondern nur als ein Modell der Realität, als eine Folge einer Gruppe von bestimmten Möglichkeiten der Anpassung die unsere Ahnen, wenngleich erfolgreich, vor vielen Generationen wählten. </seg>
<seg id="3"> Und natürlich teilen wir alle dieselben Anpassungsnotwendigkeiten. </seg>
<seg id="4"> Wir werden alle geboren. Wir bringen Kinder zur Welt. </seg>
<seg id="5"> Wir durchlaufen Initiationsrituale. </seg>
<seg id="6"> Wir müssen uns mit der unaufhaltsamen Trennung durch den Tod auseinandersetzen und somit sollte es uns nicht überraschen, dass wir alle singen, tanzen und und Kunst hervorbringen. </seg>
<seg id="7"> Aber interessant ist der einzigartige Tonfall des Liedes, der Rhythmus des Tanzes in jeder Kultur. </seg>
<seg id="8"> Dabei spielt es keine Rolle, ob es sich um die Penan in den Wäldern von Borneo handelt, oder die Voodoo-Akolythen in Haiti, oder die Krieger in der Kaisut-Wüste von Nordkenia, die Curanderos in den Anden, oder eine Karawanserei mitten in der Sahara. Dies ist zufällig der Kollege, mit dem ich vor einem Monat in die Wüste gereist bin. Oder selbst ein Yak-Hirte an den Hängen des Qomolangma, Everest, der Gottmutter der Welt. </seg>
<seg id="9"> All diese Menschen lehren uns, dass es noch andere Existenzmöglichkeiten, andere Denkweisen, andere Wege zur Orientierung auf der Erde gibt. </seg>
<seg id="10"> Und das ist eine Vorstellung, die, wenn man darüber nachdenkt, einen nur mit Hoffnung erfüllen kann. </seg>
<seg id="11"> Zusammen bilden die unzähligen Kulturen der Welt ein Netz aus spirituellem und kulturellem Leben, das die Erde umhüllt und für das Wohl der Erde genauso wichtig ist, wie das biologische Lebensnetz, das man als Biosphäre kennt. </seg>
<seg id="12"> Man kann sich dieses kulturelle Lebensnetz als eine Ethnosphäre vorstellen. Ethnosphäre kann dabei als die Gesamtsumme aller Gedanken und Träume, Mythen Ideen, Inspirationen und Intuitionen, die von der menschlichen Vorstellungskraft seit den Anfängen des Bewusstseins hervorgebracht wurden, definiert werden. </seg>
<seg id="13"> Die Ethnosphäre ist das großartige Vermächtnis der Menschheit. </seg>
<seg id="14"> Sie ist das Symbol all dessen, was wir sind und wozu wir als erstaunlich wissbegierige Spezies fähig sind. </seg>
<seg id="15"> Und genauso wie die Biosphäre stark abgetragen wurde, geschah dies mit der Ethnosphäre -- nur mit noch größerer Geschwindigkeit. </seg>
<seg id="16"> Kein Biologe würde zum Beispiel wagen zu behaupten, dass 50% oder mehr aller Arten kurz vor dem Aussterben sind, da es einfach nicht stimmt. Und doch, dieses -- das apokalyptischste Szenarium auf dem Gebiet der biologischen Vielfalt -- entspricht kaum dem, was uns als optimistischstes Szenarium auf dem Gebiet der kulturellen Vielfalt bekannt ist. </seg>
<seg id="17"> Und der entscheidende Indikator dafür ist das Aussterben der Sprachen. </seg>
4、创建训练集、验证集、测试集
python prepro.py --vocab_size 8000
部分运行结果:
trainer_interface.cc(615) LOG(INFO) Saving model: iwslt2016/segmented/bpe.model
trainer_interface.cc(626) LOG(INFO) Saving vocabs: iwslt2016/segmented/bpe.vocab
INFO:root:# Load trained bpe model
INFO:root:# Segment
INFO:root:Let's see how segmented data look like
train1: ▁David ▁G all o : ▁Das ▁ist ▁Bill ▁L ange . ▁Ich ▁bin ▁Da ve ▁G all o .
train2: ▁David ▁G all o : ▁This ▁is ▁Bill ▁L ange . ▁I ' m ▁Da ve ▁G all o .
eval1: ▁Als ▁ich ▁11 ▁Jahre ▁alt ▁war , ▁wurde ▁ich ▁eines ▁Morgen s ▁von ▁den ▁Kl ängen ▁h eller ▁Freude ▁ge we ckt .
eval2: ▁When ▁I ▁was ▁11 , ▁I ▁remember ▁w aking ▁up ▁one ▁morning ▁to ▁the ▁sound ▁of ▁j oy ▁in ▁my ▁house .
test1: ▁Als ▁ich ▁in ▁meinen ▁20 ern ▁war , ▁hatte ▁ich ▁meine ▁erste ▁Psych other ap ie - P at ient in .
INFO:root:Done
运行之后会有:
prepro.py中的内容如下:
# -*- coding: utf-8 -*-
#/usr/bin/python3
'''
Feb. 2019 by kyubyong park.
kbpark.linguist@gmail.com.
https://www.github.com/kyubyong/transformer.
Preprocess the iwslt 2016 datasets.
'''
import os
import errno
import sentencepiece as spm
import re
from hparams import Hparams
import logging
logging.basicConfig(level=logging.INFO)
def prepro(hp):
"""Load raw data -> Preprocessing -> Segmenting with sentencepice
hp: hyperparams. argparse.
"""
logging.info("# Check if raw files exist")
train1 = "iwslt2016/de-en/train.tags.de-en.de"
train2 = "iwslt2016/de-en/train.tags.de-en.en"
eval1 = "iwslt2016/de-en/IWSLT16.TED.tst2013.de-en.de.xml"
eval2 = "iwslt2016/de-en/IWSLT16.TED.tst2013.de-en.en.xml"
test1 = "iwslt2016/de-en/IWSLT16.TED.tst2014.de-en.de.xml"
test2 = "iwslt2016/de-en/IWSLT16.TED.tst2014.de-en.en.xml"
for f in (train1, train2, eval1, eval2, test1, test2):
if not os.path.isfile(f):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), f)
logging.info("# Preprocessing")
# train
_prepro = lambda x: [line.strip() for line in open(x, 'r').read().split("n")
if not line.startswith("<")]
prepro_train1, prepro_train2 = _prepro(train1), _prepro(train2)
assert len(prepro_train1)==len(prepro_train2), "Check if train source and target files match."
# eval
_prepro = lambda x: [re.sub("<[^>]+>", "", line).strip()
for line in open(x, 'r').read().split("n")
if line.startswith("<seg id")]
prepro_eval1, prepro_eval2 = _prepro(eval1), _prepro(eval2)
assert len(prepro_eval1) == len(prepro_eval2), "Check if eval source and target files match."
# test
prepro_test1, prepro_test2 = _prepro(test1), _prepro(test2)
assert len(prepro_test1) == len(prepro_test2), "Check if test source and target files match."
logging.info("Let's see how preprocessed data look like")
logging.info("prepro_train1:", prepro_train1[0])
logging.info("prepro_train2:", prepro_train2[0])
logging.info("prepro_eval1:", prepro_eval1[0])
logging.info("prepro_eval2:", prepro_eval2[0])
logging.info("prepro_test1:", prepro_test1[0])
logging.info("prepro_test2:", prepro_test2[0])
logging.info("# write preprocessed files to disk")
os.makedirs("iwslt2016/prepro", exist_ok=True)
def _write(sents, fname):
with open(fname, 'w') as fout:
fout.write("n".join(sents))
_write(prepro_train1, "iwslt2016/prepro/train.de")
_write(prepro_train2, "iwslt2016/prepro/train.en")
_write(prepro_train1+prepro_train2, "iwslt2016/prepro/train")
_write(prepro_eval1, "iwslt2016/prepro/eval.de")
_write(prepro_eval2, "iwslt2016/prepro/eval.en")
_write(prepro_test1, "iwslt2016/prepro/test.de")
_write(prepro_test2, "iwslt2016/prepro/test.en")
logging.info("# Train a joint BPE model with sentencepiece")
os.makedirs("iwslt2016/segmented", exist_ok=True)
train = '--input=iwslt2016/prepro/train --pad_id=0 --unk_id=1
--bos_id=2 --eos_id=3
--model_prefix=iwslt2016/segmented/bpe --vocab_size={}
--model_type=bpe'.format(hp.vocab_size)
spm.SentencePieceTrainer.Train(train)
logging.info("# Load trained bpe model")
sp = spm.SentencePieceProcessor()
sp.Load("iwslt2016/segmented/bpe.model")
logging.info("# Segment")
def _segment_and_write(sents, fname):
with open(fname, "w") as fout:
for sent in sents:
pieces = sp.EncodeAsPieces(sent)
fout.write(" ".join(pieces) + "n")
_segment_and_write(prepro_train1, "iwslt2016/segmented/train.de.bpe")
_segment_and_write(prepro_train2, "iwslt2016/segmented/train.en.bpe")
_segment_and_write(prepro_eval1, "iwslt2016/segmented/eval.de.bpe")
_segment_and_write(prepro_eval2, "iwslt2016/segmented/eval.en.bpe")
_segment_and_write(prepro_test1, "iwslt2016/segmented/test.de.bpe")
logging.info("Let's see how segmented data look like")
print("train1:", open("iwslt2016/segmented/train.de.bpe",'r').readline())
print("train2:", open("iwslt2016/segmented/train.en.bpe", 'r').readline())
print("eval1:", open("iwslt2016/segmented/eval.de.bpe", 'r').readline())
print("eval2:", open("iwslt2016/segmented/eval.en.bpe", 'r').readline())
print("test1:", open("iwslt2016/segmented/test.de.bpe", 'r').readline())
if __name__ == '__main__':
hparams = Hparams()
parser = hparams.parser
hp = parser.parse_args()
prepro(hp)
logging.info("Done")
train中部分内容如下:
David Gallo: Das ist Bill Lange. Ich bin Dave Gallo.
Wir werden Ihnen einige Geschichten über das Meer in Videoform erzählen.
Wir haben ein paar der unglaublichsten Aufnahmen der Titanic, die man je gesehen hat,, und wir werden Ihnen nichts davon zeigen.
Die Wahrheit ist, dass die Titanic – obwohl sie alle Kinokassenrekorde bricht – nicht gerade die aufregendste Geschichte vom Meer ist.
Ich denke, das Problem ist, dass wir das Meer für zu selbstverständlich halten.
Wenn man darüber nachdenkt, machen die Ozeane 75 % des Planeten aus.
Der Großteil der Erde ist Meerwasser.
train.en.bpe中部分内容如下:
▁David ▁G all o : ▁This ▁is ▁Bill ▁L ange . ▁I ' m ▁Da ve ▁G all o .
▁And ▁we ' re ▁going ▁to ▁tell ▁you ▁some ▁stories ▁from ▁the ▁sea ▁here ▁in ▁video .
▁We ' ve ▁got ▁some ▁of ▁the ▁most ▁incredible ▁video ▁of ▁Tit an ic ▁that ' s ▁ever ▁been ▁seen , ▁and ▁we ' re ▁not ▁going ▁to ▁show ▁you ▁any ▁of ▁it .
▁The ▁truth ▁of ▁the ▁matter ▁is ▁that ▁the ▁Tit an ic ▁-- ▁even ▁though ▁it ' s ▁break ing ▁all ▁sorts ▁of ▁box ▁office ▁record s ▁-- ▁it ' s ▁not ▁the ▁most ▁exciting ▁story ▁from ▁the ▁sea .
▁And ▁the ▁problem , ▁I ▁think , ▁is ▁that ▁we ▁take ▁the ▁ocean ▁for ▁gr anted .
▁When ▁you ▁think ▁about ▁it , ▁the ▁oce ans ▁are ▁75 ▁percent ▁of ▁the ▁planet .
▁Most ▁of ▁the ▁planet ▁is ▁ocean ▁water .
bpe.vocab部分内容如下:
<pad> 0
<unk> 0
<s> 0
</s> 0
en -0
er -1
in -2
▁t -3
ch -4
▁a -5
▁d -6
▁w -7
▁s -8
▁th -9
nd -10
ie -11
es -12
5、train.py
# -*- coding: utf-8 -*-
#/usr/bin/python3
'''
Feb. 2019 by kyubyong park.
kbpark.linguist@gmail.com.
https://www.github.com/kyubyong/transformer
'''
import tensorflow as tf
from model import Transformer
from tqdm import tqdm
from data_load import get_batch
from utils import save_hparams, save_variable_specs, get_hypotheses, calc_bleu
import os
from hparams import Hparams
import math
import logging
logging.basicConfig(level=logging.INFO)
logging.info("# hparams")
hparams = Hparams()
parser = hparams.parser
hp = parser.parse_args()
save_hparams(hp, hp.logdir)
logging.info("# Prepare train/eval batches")
train_batches, num_train_batches, num_train_samples = get_batch(hp.train1, hp.train2,
hp.maxlen1, hp.maxlen2,
hp.vocab, hp.batch_size,
shuffle=True)
eval_batches, num_eval_batches, num_eval_samples = get_batch(hp.eval1, hp.eval2,
100000, 100000,
hp.vocab, hp.batch_size,
shuffle=False)
# create a iterator of the correct shape and type
iter = tf.data.Iterator.from_structure(train_batches.output_types, train_batches.output_shapes)
xs, ys = iter.get_next()
train_init_op = iter.make_initializer(train_batches)
eval_init_op = iter.make_initializer(eval_batches)
logging.info("# Load model")
m = Transformer(hp)
loss, train_op, global_step, train_summaries = m.train(xs, ys)
y_hat, eval_summaries = m.eval(xs, ys)
# y_hat = m.infer(xs, ys)
logging.info("# Session")
saver = tf.train.Saver(max_to_keep=hp.num_epochs)
with tf.Session() as sess:
ckpt = tf.train.latest_checkpoint(hp.logdir)
if ckpt is None:
logging.info("Initializing from scratch")
sess.run(tf.global_variables_initializer())
save_variable_specs(os.path.join(hp.logdir, "specs"))
else:
saver.restore(sess, ckpt)
summary_writer = tf.summary.FileWriter(hp.logdir, sess.graph)
sess.run(train_init_op)
total_steps = hp.num_epochs * num_train_batches
_gs = sess.run(global_step)
for i in tqdm(range(_gs, total_steps+1)):
_, _gs, _summary = sess.run([train_op, global_step, train_summaries])
epoch = math.ceil(_gs / num_train_batches)
summary_writer.add_summary(_summary, _gs)
if _gs and _gs % num_train_batches == 0:
logging.info("epoch {} is done".format(epoch))
_loss = sess.run(loss) # train loss
logging.info("# test evaluation")
_, _eval_summaries = sess.run([eval_init_op, eval_summaries])
summary_writer.add_summary(_eval_summaries, _gs)
logging.info("# get hypotheses")
hypotheses = get_hypotheses(num_eval_batches, num_eval_samples, sess, y_hat, m.idx2token)
logging.info("# write results")
model_output = "iwslt2016_E%02dL%.2f" % (epoch, _loss)
if not os.path.exists(hp.evaldir): os.makedirs(hp.evaldir)
translation = os.path.join(hp.evaldir, model_output)
with open(translation, 'w') as fout:
fout.write("n".join(hypotheses))
logging.info("# calc bleu score and append it to translation")
calc_bleu(hp.eval3, translation)
logging.info("# save models")
ckpt_name = os.path.join(hp.logdir, model_output)
saver.save(sess, ckpt_name, global_step=_gs)
logging.info("after training of {} epochs, {} has been saved.".format(epoch, ckpt_name))
logging.info("# fall back to train mode")
sess.run(train_init_op)
summary_writer.close()
logging.info("Done")
我们一行行来看:
首先调用了hparams.py中的函数:
import argparse
class Hparams:
parser = argparse.ArgumentParser()
# prepro
parser.add_argument('--vocab_size', default=32000, type=int)
# train
## files
parser.add_argument('--train1', default='iwslt2016/segmented/train.de.bpe',
help="german training segmented data")
parser.add_argument('--train2', default='iwslt2016/segmented/train.en.bpe',
help="english training segmented data")
parser.add_argument('--eval1', default='iwslt2016/segmented/eval.de.bpe',
help="german evaluation segmented data")
parser.add_argument('--eval2', default='iwslt2016/segmented/eval.en.bpe',
help="english evaluation segmented data")
parser.add_argument('--eval3', default='iwslt2016/prepro/eval.en',
help="english evaluation unsegmented data")
## vocabulary
parser.add_argument('--vocab', default='iwslt2016/segmented/bpe.vocab',
help="vocabulary file path")
# training scheme
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--eval_batch_size', default=128, type=int)
parser.add_argument('--lr', default=0.0003, type=float, help="learning rate")
parser.add_argument('--warmup_steps', default=4000, type=int)
parser.add_argument('--logdir', default="log/1", help="log directory")
parser.add_argument('--num_epochs', default=20, type=int)
parser.add_argument('--evaldir', default="eval/1", help="evaluation dir")
# model
parser.add_argument('--d_model', default=512, type=int,
help="hidden dimension of encoder/decoder")
parser.add_argument('--d_ff', default=2048, type=int,
help="hidden dimension of feedforward layer")
parser.add_argument('--num_blocks', default=6, type=int,
help="number of encoder/decoder blocks")
parser.add_argument('--num_heads', default=8, type=int,
help="number of attention heads")
parser.add_argument('--maxlen1', default=100, type=int,
help="maximum length of a source sequence")
parser.add_argument('--maxlen2', default=100, type=int,
help="maximum length of a target sequence")
parser.add_argument('--dropout_rate', default=0.3, type=float)
parser.add_argument('--smoothing', default=0.1, type=float,
help="label smoothing rate")
# test
parser.add_argument('--test1', default='iwslt2016/segmented/test.de.bpe',
help="german test segmented data")
parser.add_argument('--test2', default='iwslt2016/prepro/test.en',
help="english test data")
parser.add_argument('--ckpt', help="checkpoint file path")
parser.add_argument('--test_batch_size', default=128, type=int)
parser.add_argument('--testdir', default="test/1", help="test result dir")
主要是一些超参数的设置。
然后是data_load.py中用来加载数据集:
# -*- coding: utf-8 -*-
#/usr/bin/python3
'''
Feb. 2019 by kyubyong park.
kbpark.linguist@gmail.com.
https://www.github.com/kyubyong/transformer
Note.
if safe, entities on the source side have the prefix 1, and the target side 2, for convenience.
For example, fpath1, fpath2 means source file path and target file path, respectively.
'''
import tensorflow as tf
from utils import calc_num_batches
def load_vocab(vocab_fpath):
'''Loads vocabulary file and returns idx<->token maps
vocab_fpath: string. vocabulary file path.
Note that these are reserved
0: <pad>, 1: <unk>, 2: <s>, 3: </s>
Returns
two dictionaries.
'''
vocab = [line.split()[0] for line in open(vocab_fpath, 'r').read().splitlines()]
token2idx = {token: idx for idx, token in enumerate(vocab)}
idx2token = {idx: token for idx, token in enumerate(vocab)}
return token2idx, idx2token
def load_data(fpath1, fpath2, maxlen1, maxlen2):
'''Loads source and target data and filters out too lengthy samples.
fpath1: source file path. string.
fpath2: target file path. string.
maxlen1: source sent maximum length. scalar.
maxlen2: target sent maximum length. scalar.
Returns
sents1: list of source sents
sents2: list of target sents
'''
sents1, sents2 = [], []
with open(fpath1, 'r') as f1, open(fpath2, 'r') as f2:
for sent1, sent2 in zip(f1, f2):
if len(sent1.split()) + 1 > maxlen1: continue # 1: </s>
if len(sent2.split()) + 1 > maxlen2: continue # 1: </s>
sents1.append(sent1.strip())
sents2.append(sent2.strip())
return sents1, sents2
def encode(inp, type, dict):
'''Converts string to number. Used for `generator_fn`.
inp: 1d byte array.
type: "x" (source side) or "y" (target side)
dict: token2idx dictionary
Returns
list of numbers
'''
inp_str = inp.decode("utf-8")
if type=="x": tokens = inp_str.split() + ["</s>"]
else: tokens = ["<s>"] + inp_str.split() + ["</s>"]
x = [dict.get(t, dict["<unk>"]) for t in tokens]
return x
def generator_fn(sents1, sents2, vocab_fpath):
'''Generates training / evaluation data
sents1: list of source sents
sents2: list of target sents
vocab_fpath: string. vocabulary file path.
yields
xs: tuple of
x: list of source token ids in a sent
x_seqlen: int. sequence length of x
sent1: str. raw source (=input) sentence
labels: tuple of
decoder_input: decoder_input: list of encoded decoder inputs
y: list of target token ids in a sent
y_seqlen: int. sequence length of y
sent2: str. target sentence
'''
token2idx, _ = load_vocab(vocab_fpath)
for sent1, sent2 in zip(sents1, sents2):
x = encode(sent1, "x", token2idx)
y = encode(sent2, "y", token2idx)
decoder_input, y = y[:-1], y[1:]
x_seqlen, y_seqlen = len(x), len(y)
yield (x, x_seqlen, sent1), (decoder_input, y, y_seqlen, sent2)
def input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=False):
'''Batchify data
sents1: list of source sents
sents2: list of target sents
vocab_fpath: string. vocabulary file path.
batch_size: scalar
shuffle: boolean
Returns
xs: tuple of
x: int32 tensor. (N, T1)
x_seqlens: int32 tensor. (N,)
sents1: str tensor. (N,)
ys: tuple of
decoder_input: int32 tensor. (N, T2)
y: int32 tensor. (N, T2)
y_seqlen: int32 tensor. (N, )
sents2: str tensor. (N,)
'''
shapes = (([None], (), ()),
([None], [None], (), ()))
types = ((tf.int32, tf.int32, tf.string),
(tf.int32, tf.int32, tf.int32, tf.string))
paddings = ((0, 0, ''),
(0, 0, 0, ''))
dataset = tf.data.Dataset.from_generator(
generator_fn,
output_shapes=shapes,
output_types=types,
args=(sents1, sents2, vocab_fpath)) # <- arguments for generator_fn. converted to np string arrays
if shuffle: # for training
dataset = dataset.shuffle(128*batch_size)
dataset = dataset.repeat() # iterate forever
dataset = dataset.padded_batch(batch_size, shapes, paddings).prefetch(1)
return dataset
def get_batch(fpath1, fpath2, maxlen1, maxlen2, vocab_fpath, batch_size, shuffle=False):
'''Gets training / evaluation mini-batches
fpath1: source file path. string.
fpath2: target file path. string.
maxlen1: source sent maximum length. scalar.
maxlen2: target sent maximum length. scalar.
vocab_fpath: string. vocabulary file path.
batch_size: scalar
shuffle: boolean
Returns
batches
num_batches: number of mini-batches
num_samples
'''
sents1, sents2 = load_data(fpath1, fpath2, maxlen1, maxlen2)
batches = input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=shuffle)
num_batches = calc_num_batches(len(sents1), batch_size)
return batches, num_batches, len(sents1)
6、看一下相关模型model.py
# -*- coding: utf-8 -*-
# /usr/bin/python3
'''
Feb. 2019 by kyubyong park.
kbpark.linguist@gmail.com.
https://www.github.com/kyubyong/transformer
Transformer network
'''
import tensorflow as tf
from data_load import load_vocab
from modules import get_token_embeddings, ff, positional_encoding, multihead_attention, label_smoothing, noam_scheme
from utils import convert_idx_to_token_tensor
from tqdm import tqdm
import logging
logging.basicConfig(level=logging.INFO)
class Transformer:
'''
xs: tuple of
x: int32 tensor. (N, T1)
x_seqlens: int32 tensor. (N,)
sents1: str tensor. (N,)
ys: tuple of
decoder_input: int32 tensor. (N, T2)
y: int32 tensor. (N, T2)
y_seqlen: int32 tensor. (N, )
sents2: str tensor. (N,)
training: boolean.
'''
def __init__(self, hp):
self.hp = hp
self.token2idx, self.idx2token = load_vocab(hp.vocab)
self.embeddings = get_token_embeddings(self.hp.vocab_size, self.hp.d_model, zero_pad=True)
def encode(self, xs, training=True):
'''
Returns
memory: encoder outputs. (N, T1, d_model)
'''
with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
x, seqlens, sents1 = xs
# src_masks
src_masks = tf.math.equal(x, 0) # (N, T1)
# embedding
enc = tf.nn.embedding_lookup(self.embeddings, x) # (N, T1, d_model)
enc *= self.hp.d_model**0.5 # scale
enc += positional_encoding(enc, self.hp.maxlen1)
enc = tf.layers.dropout(enc, self.hp.dropout_rate, training=training)
## Blocks
for i in range(self.hp.num_blocks):
with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
# self-attention
enc = multihead_attention(queries=enc,
keys=enc,
values=enc,
key_masks=src_masks,
num_heads=self.hp.num_heads,
dropout_rate=self.hp.dropout_rate,
training=training,
causality=False)
# feed forward
enc = ff(enc, num_units=[self.hp.d_ff, self.hp.d_model])
memory = enc
return memory, sents1, src_masks
def decode(self, ys, memory, src_masks, training=True):
'''
memory: encoder outputs. (N, T1, d_model)
src_masks: (N, T1)
Returns
logits: (N, T2, V). float32.
y_hat: (N, T2). int32
y: (N, T2). int32
sents2: (N,). string.
'''
with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
decoder_inputs, y, seqlens, sents2 = ys
# tgt_masks
tgt_masks = tf.math.equal(decoder_inputs, 0) # (N, T2)
# embedding
dec = tf.nn.embedding_lookup(self.embeddings, decoder_inputs) # (N, T2, d_model)
dec *= self.hp.d_model ** 0.5 # scale
dec += positional_encoding(dec, self.hp.maxlen2)
dec = tf.layers.dropout(dec, self.hp.dropout_rate, training=training)
# Blocks
for i in range(self.hp.num_blocks):
with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
# Masked self-attention (Note that causality is True at this time)
dec = multihead_attention(queries=dec,
keys=dec,
values=dec,
key_masks=tgt_masks,
num_heads=self.hp.num_heads,
dropout_rate=self.hp.dropout_rate,
training=training,
causality=True,
scope="self_attention")
# Vanilla attention
dec = multihead_attention(queries=dec,
keys=memory,
values=memory,
key_masks=src_masks,
num_heads=self.hp.num_heads,
dropout_rate=self.hp.dropout_rate,
training=training,
causality=False,
scope="vanilla_attention")
### Feed Forward
dec = ff(dec, num_units=[self.hp.d_ff, self.hp.d_model])
# Final linear projection (embedding weights are shared)
weights = tf.transpose(self.embeddings) # (d_model, vocab_size)
logits = tf.einsum('ntd,dk->ntk', dec, weights) # (N, T2, vocab_size)
y_hat = tf.to_int32(tf.argmax(logits, axis=-1))
return logits, y_hat, y, sents2
def train(self, xs, ys):
'''
Returns
loss: scalar.
train_op: training operation
global_step: scalar.
summaries: training summary node
'''
# forward
memory, sents1, src_masks = self.encode(xs)
logits, preds, y, sents2 = self.decode(ys, memory, src_masks)
# train scheme
y_ = label_smoothing(tf.one_hot(y, depth=self.hp.vocab_size))
ce = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=y_)
nonpadding = tf.to_float(tf.not_equal(y, self.token2idx["<pad>"])) # 0: <pad>
loss = tf.reduce_sum(ce * nonpadding) / (tf.reduce_sum(nonpadding) + 1e-7)
global_step = tf.train.get_or_create_global_step()
lr = noam_scheme(self.hp.lr, global_step, self.hp.warmup_steps)
optimizer = tf.train.AdamOptimizer(lr)
train_op = optimizer.minimize(loss, global_step=global_step)
tf.summary.scalar('lr', lr)
tf.summary.scalar("loss", loss)
tf.summary.scalar("global_step", global_step)
summaries = tf.summary.merge_all()
return loss, train_op, global_step, summaries
def eval(self, xs, ys):
'''Predicts autoregressively
At inference, input ys is ignored.
Returns
y_hat: (N, T2)
'''
decoder_inputs, y, y_seqlen, sents2 = ys
decoder_inputs = tf.ones((tf.shape(xs[0])[0], 1), tf.int32) * self.token2idx["<s>"]
ys = (decoder_inputs, y, y_seqlen, sents2)
memory, sents1, src_masks = self.encode(xs, False)
logging.info("Inference graph is being built. Please be patient.")
for _ in tqdm(range(self.hp.maxlen2)):
logits, y_hat, y, sents2 = self.decode(ys, memory, src_masks, False)
if tf.reduce_sum(y_hat, 1) == self.token2idx["<pad>"]: break
_decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
ys = (_decoder_inputs, y, y_seqlen, sents2)
# monitor a random sample
n = tf.random_uniform((), 0, tf.shape(y_hat)[0]-1, tf.int32)
sent1 = sents1[n]
pred = convert_idx_to_token_tensor(y_hat[n], self.idx2token)
sent2 = sents2[n]
tf.summary.text("sent1", sent1)
tf.summary.text("pred", pred)
tf.summary.text("sent2", sent2)
summaries = tf.summary.merge_all()
return y_hat, summaries
- 对事件委托绑定click的事件的解绑
- 免杀后门之MSF&Veil-Evasion的完美结合
- 【52ABP实战教程】0.3-- 从GitHub推送代码回VSTS实现双向同步
- css绝对定位如何在不同分辨率下的电脑正常显示定位位置?
- nvm安装node和npm,个人踩坑记录
- clang_intprt_t类型探究
- 学习zepto.js(Hello World)
- JS中函数声明与函数表达式的异同
- [技巧]看我如何通过Weeman+Ettercap拿下路由器管理权限
- 一分钟理清Vue-cli 代码构建步骤。
- 点击图片放大至原始图片大小
- 替代jquery1.9版本以前的toggle事件函数(开关)
- 总结CSS3新特性(Animation篇)
- Scrapy爬虫入门
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- RocketMQ学习六-消息存储
- swoole 实现 unixSocket 通信
- mybatis-plus一对多关联查询踩坑
- 深入Spring Security魔幻山谷-获取认证机制核心原理讲解
- 文本相似性的总结
- Java面试题总结之JDBC 和Hibernate
- Mac 下搭建 Clion + OpenCV4.x 的开发环境
- 超详细,Windows系统搭建Flink官方练习环境
- MySQL 覆盖索引与延迟关联
- Java面试题总结之数据结构、算法和计算机基础(刘小牛和丝音的爱情故事1)
- 在Java中什么时候才要考虑线程安全
- android功耗优化(2)--对齐唤醒
- Android 功耗(3)---高通功耗问题分析方法
- 搞定Java快速排序
- Android 功耗(4)---MTK平台待机功耗分析流程