The tf.nn.embedding_lookup function

tf.nn.embedding_lookup(params, ids) gathers the rows of the params tensor selected by ids, so the result has shape ids.shape + params.shape[1:].

Example.

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API; tf.Session is used below

vocab = ['the','like','between','did','just','national','day','country','under','such','second']

emb = np.array([[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.044457, -0.49688, -0.17862],
   [0.36808, 0.20834, -0.22319, 0.046283, 0.20098, 0.27515, -0.77127, -0.76804],
   [0.7503, 0.71623, -0.27033, 0.20059, -0.17008, 0.68568, -0.061672, -0.054638],
   [0.042523, -0.21172, 0.044739, -0.19248, 0.26224, 0.0043991, -0.88195, 0.55184],
   [0.17698, 0.065221, 0.28548, -0.4243, 0.7499, -0.14892, -0.66786, 0.11788],
   [-1.1105, 0.94945, -0.17078, 0.93037, -0.2477, -0.70633, -0.8649, -0.56118],
   [0.11626, 0.53897, -0.39514, -0.26027, 0.57706, -0.79198, -0.88374, 0.30119],
   [-0.13531, 0.15485, -0.07309, 0.034013, -0.054457, -0.20541, -0.60086, -0.22407],
   [ 0.13721, -0.295, -0.05916, -0.59235, 0.02301, 0.21884, -0.34254, -0.70213],
   [ 0.61012, 0.33512, -0.53499, 0.36139, -0.39866, 0.70627, -0.18699, -0.77246 ],
   [ -0.29809, 0.28069, 0.087102, 0.54455, 0.70003, 0.44778, -0.72565, 0.62309 ]])

emb.shape
# (11, 8)
from collections import OrderedDict
tf_embedding = tf.constant(emb, dtype=tf.float32)
input_str = "like the country"
word_to_idx = OrderedDict({w:vocab.index(w) for w in input_str.split() if w in vocab})
print(word_to_idx)
# OrderedDict([('like', 1), ('the', 0), ('country', 7)])

When a 1-D array is passed as the ids parameter (passing [1, 0, 7]), each index selects one row of the embedding matrix:

>>> tf.nn.embedding_lookup(tf_embedding, list(word_to_idx.values())).eval(session=tf.Session())
array([[ 0.36807999,  0.20834   , -0.22318999,  0.046283  ,  0.20097999,
         0.27515   , -0.77126998, -0.76804   ],
       [ 0.41800001,  0.24968   , -0.41242   ,  0.1217    ,  0.34527001,
        -0.044457  , -0.49687999, -0.17862   ],
       [-0.13530999,  0.15485001, -0.07309   ,  0.034013  , -0.054457  ,
        -0.20541   , -0.60086   , -0.22407   ]], dtype=float32)
>>> tf.nn.embedding_lookup(tf_embedding, list(word_to_idx.values())).eval(session=tf.Session()).shape
(3, 8)
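
In the 1-D case the lookup is plain row gathering: NumPy fancy indexing on emb returns the same three rows (up to float32 rounding). A small sketch, assuming the emb, word_to_idx, and tf_embedding objects defined above are still in scope:

>>> emb[list(word_to_idx.values())].shape   # same as emb[[1, 0, 7]]
(3, 8)
>>> np.allclose(emb[list(word_to_idx.values())],
...             tf.nn.embedding_lookup(tf_embedding, list(word_to_idx.values())).eval(session=tf.Session()))
True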

When a 2-D array is passed as the ids parameter (passing shape (2, 2)), the lookup is applied element-wise and the result gains an extra leading dimension:

>>> tf.nn.embedding_lookup(tf_embedding, [[0,2],[0,1]]).eval(session=tf.Session())
array([[[ 0.418   ,  0.24968 , -0.41242 ,  0.1217  ,  0.34527 ,
         -0.044457, -0.49688 , -0.17862 ],
        [ 0.7503  ,  0.71623 , -0.27033 ,  0.20059 , -0.17008 ,
          0.68568 , -0.061672, -0.054638]],

       [[ 0.418   ,  0.24968 , -0.41242 ,  0.1217  ,  0.34527 ,
         -0.044457, -0.49688 , -0.17862 ],
        [ 0.36808 ,  0.20834 , -0.22319 ,  0.046283,  0.20098 ,
          0.27515 , -0.77127 , -0.76804 ]]], dtype=float32)
>>> tf.nn.embedding_lookup(tf_embedding, [[0,2],[0,1]]).eval(session=tf.Session()).shape       
(2, 2, 8)
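
The shape rule generalizes as result.shape == ids.shape + params.shape[1:], which is why a (2, 2) index array over 8-dimensional embeddings yields (2, 2, 8). As a sketch, assuming TensorFlow 2.x (where eager execution removes the need for tf.Session) and the emb array from above, the same lookup looks like this:

import numpy as np
import tensorflow as tf  # TensorFlow 2.x, eager execution

tf_embedding = tf.constant(emb, dtype=tf.float32)             # emb is the (11, 8) array defined above
out = tf.nn.embedding_lookup(tf_embedding, [[0, 2], [0, 1]])
print(out.shape)  # (2, 2, 8) = ids shape (2, 2) + embedding dimension (8,)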

Reference: https://stackoverflow.com/questions/34870614/what-does-tf-nn-embedding-lookup-function-do
