Python in Practice: Building an ESIM Model (Keras Version)

Date: 2022-07-24

This article walks through building the ESIM model (Enhanced LSTM for Natural Language Inference) in Keras.

For notes on how ESIM works, see: Paper notes & translation — Enhanced LSTM for Natural Language Inference (ESIM).

ESIM consists of three parts: input encoding, local inference modeling, and inference composition. In the architecture figure from the paper, ESIM is the left branch of the diagram.

Brief code for each of the three parts follows.

1. Input encoding

1.1 Theory

\bar{a}_i = \text{BiLSTM}(a, i), \quad \forall i \in [1, 2, \dots, l_a]
\bar{b}_j = \text{BiLSTM}(b, j), \quad \forall j \in [1, 2, \dots, l_b]
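As a shape-level sketch (plain numpy, toy dimensions chosen here for illustration): a BiLSTM with hidden size 300 concatenates its forward and backward hidden states, so each token of the encoded sequence becomes a 600-dimensional vector.

```python
import numpy as np

# Shape-level sketch of BiLSTM encoding (stand-in random states,
# hypothetical toy sequence length).
l_a, hidden = 5, 300
h_forward = np.random.randn(l_a, hidden)   # stand-in forward LSTM states
h_backward = np.random.randn(l_a, hidden)  # stand-in backward LSTM states

# Bidirectional(...) concatenates the two directions per time step
a_bar = np.concatenate([h_forward, h_backward], axis=-1)
print(a_bar.shape)  # (5, 600)
```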

1.2 Implementation

from keras.layers import (Input, Embedding, Bidirectional, LSTM, Dot, Softmax,
                          Lambda, Multiply, Subtract, Concatenate, Dense, Dropout)
from keras import backend as K

# Token-id inputs for the two sentences a (premise) and b (hypothesis);
# Embedding expects integer indices, so dtype is 'int32' rather than 'float32'
i1 = Input(shape=(SentenceLen,), dtype='int32')
i2 = Input(shape=(SentenceLen,), dtype='int32')

# Embedding configuration elided in the original ([CONFIG])
x1 = Embedding([CONFIG])(i1)
x2 = Embedding([CONFIG])(i2)

# BiLSTM encoder; return_sequences=True keeps one 600-d vector per token
x1 = Bidirectional(LSTM(300, return_sequences=True))(x1)
x2 = Bidirectional(LSTM(300, return_sequences=True))(x2)

2. Local inference modeling

2.1 Theory

\hat{a}_i = \sum_{j=1}^{l_b} \frac{\exp(e_{ij})}{\sum_{k=1}^{l_b} \exp(e_{ik})} \bar{b}_j, \quad \forall i \in [1, 2, \dots, l_a]
\hat{b}_j = \sum_{i=1}^{l_a} \frac{\exp(e_{ij})}{\sum_{k=1}^{l_a} \exp(e_{kj})} \bar{a}_i, \quad \forall j \in [1, 2, \dots, l_b]
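The equations above are ordinary dot-product attention. A minimal numpy sketch (toy sizes, all names hypothetical) makes the two softmax directions concrete: `a_hat[i]` is a weighted average over the rows of `b_bar`, and `b_hat[j]` a weighted average over the rows of `a_bar`.

```python
import numpy as np

# Soft-alignment sketch (toy sizes): e[i, j] = <a_bar[i], b_bar[j]>.
l_a, l_b, d = 4, 6, 8
rng = np.random.default_rng(0)
a_bar = rng.standard_normal((l_a, d))
b_bar = rng.standard_normal((l_b, d))

e = a_bar @ b_bar.T                    # (l_a, l_b) similarity matrix

def softmax(x, axis):
    ex = np.exp(x - x.max(axis=axis, keepdims=True))
    return ex / ex.sum(axis=axis, keepdims=True)

a_hat = softmax(e, axis=1) @ b_bar     # (l_a, d): each row attends over b
b_hat = softmax(e, axis=0).T @ a_bar   # (l_b, d): each row attends over a
```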

2.2 Implementation

e = Dot(axes=2)([x1, x2])                               # e[i, j] = <x1_i, x2_j>, shape (l_a, l_b)
e1 = Softmax(axis=2)(e)                                 # normalize over x2 positions (weights for a-hat)
e2 = Softmax(axis=1)(e)                                 # normalize over x1 positions (weights for b-hat)
e1 = Lambda(K.expand_dims, arguments={'axis': 3})(e1)   # (l_a, l_b, 1)
e2 = Lambda(K.expand_dims, arguments={'axis': 3})(e2)   # (l_a, l_b, 1)

# a-hat: for each x1 position, the attention-weighted sum of x2 vectors
_x1 = Lambda(K.expand_dims, arguments={'axis': 1})(x2)  # (1, l_b, 600)
_x1 = Multiply()([e1, _x1])                             # broadcast to (l_a, l_b, 600)
_x1 = Lambda(K.sum, arguments={'axis': 2})(_x1)         # (l_a, 600)

# b-hat: for each x2 position, the attention-weighted sum of x1 vectors
_x2 = Lambda(K.expand_dims, arguments={'axis': 2})(x1)  # (l_a, 1, 600)
_x2 = Multiply()([e2, _x2])                             # broadcast to (l_a, l_b, 600)
_x2 = Lambda(K.sum, arguments={'axis': 1})(_x2)         # (l_b, 600)
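As a sanity check (plain numpy, toy batch sizes, all names hypothetical), the expand-dims / multiply / sum pipeline used above is just a batched matrix product of the attention weights with the other encoded sequence:

```python
import numpy as np

# Verify the broadcast-multiply-sum pipeline equals a batched matmul.
batch, l_a, l_b, d = 2, 4, 6, 8
rng = np.random.default_rng(0)
x1 = rng.standard_normal((batch, l_a, d))
x2 = rng.standard_normal((batch, l_b, d))

e = np.einsum('bik,bjk->bij', x1, x2)  # mirrors Dot(axes=2)

def softmax(x, axis):
    ex = np.exp(x - x.max(axis=axis, keepdims=True))
    return ex / ex.sum(axis=axis, keepdims=True)

e1 = softmax(e, axis=2)                # mirrors Softmax(axis=2)

# Pipeline: expand dims, broadcast-multiply, sum over the l_b axis
pipeline = (e1[..., None] * x2[:, None, :, :]).sum(axis=2)
# Equivalent batched matrix product
matmul = np.einsum('bij,bjd->bid', e1, x2)
assert np.allclose(pipeline, matmul)
```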

3. Inference composition

3.1 Theory

m_a = [\bar{a}; \hat{a}; \bar{a} - \hat{a}; \bar{a} \odot \hat{a}]
m_b = [\bar{b}; \hat{b}; \bar{b} - \hat{b}; \bar{b} \odot \hat{b}]
v_{a,i} = \text{BiLSTM}(m_a, i)
v_{b,j} = \text{BiLSTM}(m_b, j)
v_{a,\text{ave}} = \sum_{i=1}^{l_a} \frac{v_{a,i}}{l_a}
v_{a,\text{max}} = \max_{i=1}^{l_a} v_{a,i}
v_{b,\text{ave}} = \sum_{j=1}^{l_b} \frac{v_{b,j}}{l_b}
v_{b,\text{max}} = \max_{j=1}^{l_b} v_{b,j}
v = [v_{a,\text{ave}}; v_{a,\text{max}}; v_{b,\text{ave}}; v_{b,\text{max}}]
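The pooling step can be sketched in plain numpy (toy sizes, names hypothetical): average- and max-pool each composition output over the time axis, then concatenate the four resulting vectors, giving a fixed-size vector of dimension 4 × 600 = 2400.

```python
import numpy as np

# Pooling sketch: collapse the time axis of each BiLSTM output,
# then concatenate [avg_a; max_a; avg_b; max_b].
batch, l_a, l_b, d = 2, 4, 6, 600
rng = np.random.default_rng(0)
y1 = rng.standard_normal((batch, l_a, d))  # stand-in composition output for a
y2 = rng.standard_normal((batch, l_b, d))  # stand-in composition output for b

v = np.concatenate([y1.mean(axis=1), y1.max(axis=1),
                    y2.mean(axis=1), y2.max(axis=1)], axis=-1)
print(v.shape)  # (2, 2400)
```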

3.2 Implementation

from keras.models import Model

# Enhanced representations: [x; x_hat; x - x_hat; x * x_hat]
m1 = Concatenate()([x1, _x1, Subtract()([x1, _x1]), Multiply()([x1, _x1])])
m2 = Concatenate()([x2, _x2, Subtract()([x2, _x2]), Multiply()([x2, _x2])])

# Composition BiLSTM over the enhanced representations
y1 = Bidirectional(LSTM(300, return_sequences=True))(m1)
y2 = Bidirectional(LSTM(300, return_sequences=True))(m2)

# Average- and max-pooling over the time axis
mx1 = Lambda(K.max, arguments={'axis': 1})(y1)
av1 = Lambda(K.mean, arguments={'axis': 1})(y1)
mx2 = Lambda(K.max, arguments={'axis': 1})(y2)
av2 = Lambda(K.mean, arguments={'axis': 1})(y2)

# Final classifier: pooled vector v -> MLP -> 2-way softmax
y = Concatenate()([av1, mx1, av2, mx2])
y = Dense(1024, activation='tanh')(y)
y = Dropout(0.5)(y)
y = Dense(1024, activation='tanh')(y)
y = Dropout(0.5)(y)
y = Dense(2, activation='softmax')(y)

model = Model(inputs=[i1, i2], outputs=y)
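The classifier head is a plain feed-forward pass; a numpy sketch (random weights, toy dimensions, dropout omitted as at inference time) shows the shapes and confirms the softmax output is a valid probability distribution over the two classes:

```python
import numpy as np

# Forward-pass sketch of the classifier head: two tanh layers, 2-way softmax.
rng = np.random.default_rng(0)
v = rng.standard_normal((2, 2400))                     # pooled vector (batch of 2)
W1, b1 = rng.standard_normal((2400, 1024)) * 0.01, np.zeros(1024)
W2, b2 = rng.standard_normal((1024, 1024)) * 0.01, np.zeros(1024)
W3, b3 = rng.standard_normal((1024, 2)) * 0.01, np.zeros(2)

h = np.tanh(v @ W1 + b1)                               # Dense(1024, tanh)
h = np.tanh(h @ W2 + b2)                               # Dense(1024, tanh)
logits = h @ W3 + b3                                   # Dense(2)
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
```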