A word on the TensorFlow conditional operation tf.cond




Sunday, January 21, 2018
tf.cond is used to define an if-statement in a TensorFlow computation graph. It has the following signature:

cond(
    pred,
    true_fn=None,
    false_fn=None)
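As a minimal sketch of the signature in use (assuming TensorFlow 1.15 or 2.x, where tensorflow.compat.v1 exposes the 1.x graph API; the names flag and a are illustrative only), pred is a scalar boolean tensor, here fed at run time, and the two branch functions take no arguments and return tensors of the same type:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and sessions

# pred must be a scalar boolean tensor; here it is fed at run time.
flag = tf.placeholder(tf.bool, shape=[])
a = tf.constant(2.0)

# Both branch functions take no arguments and return a float32 tensor.
out = tf.cond(flag, lambda: a * 10.0, lambda: a - 1.0)

with tf.Session() as sess:
    on = float(sess.run(out, feed_dict={flag: True}))
    off = float(sess.run(out, feed_dict={flag: False}))

print(on, off)  # 20.0 1.0
```

Because pred is a tensor rather than a Python bool, the same graph can take either branch on different sess.run calls.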
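An important point to note is that tf.cond does not make everything a branch uses conditional. Only the operations created inside true_fn and false_fn are executed conditionally; any operation that a branch uses but that was created outside those functions is executed unconditionally, regardless of the value of pred. The TensorFlow API documentation mentions this, but it is not obvious on first reading. In the following code, x > 1 is False, so res takes the value of y (5). Yet z = tf.add(x, y) is created outside the branch functions and is an input to the true branch, so it is still computed and evaluates to 6 (the sum of x and y). Here z is also fetched directly, but tf.add would run even if only res were requested.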

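import tensorflow as tf  # TensorFlow 1.x graph-mode API

x = tf.constant(1)
y = tf.constant(5)
# z is created outside the branch functions, so tf.cond treats it as an input
# and tf.add runs whichever branch is taken
z = tf.add(x, y)
# pred (x > 1) is False, so res takes the value returned by false_fn, i.e. y
res = tf.cond(x > 1, lambda: tf.multiply(x, z), lambda: y)

with tf.Session() as sess:
    print(sess.run([res, z]))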
output: [5, 6]

If you want an operation to run only when its branch is selected, it must be created inside the corresponding branch function (true_fn or false_fn), not merely referenced from it. This is described in detail in this stackoverflow post.
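The difference can be made visible with a side effect. The sketch below (again assuming tensorflow.compat.v1 on TensorFlow 1.15 or 2.x; the counter variables outside_count and inside_count are hypothetical illustrations) increments one counter from an op created outside the branch functions and another from an op created inside true_fn. With pred False, after three runs of res only the outside counter has moved:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs, sessions and control flow

x = tf.constant(1)
y = tf.constant(5)

# Hypothetical counters recording how often each increment op actually ran.
outside_count = tf.Variable(0, name="outside_count")
inside_count = tf.Variable(0, name="inside_count")

# Created OUTSIDE the branch functions but used by true_fn, so it becomes an
# input of tf.cond and runs every time res is evaluated, whatever pred is.
outside_inc = tf.assign_add(outside_count, 1)


def true_fn():
    # Created INSIDE true_fn, so it only runs when pred is True.
    inside_inc = tf.assign_add(inside_count, 1)
    return inside_inc + outside_inc


def false_fn():
    return y


res = tf.cond(x > 1, true_fn, false_fn)  # x > 1 is False

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        sess.run(res)  # evaluates to y (5) each time
    counts = [int(v) for v in sess.run([outside_count, inside_count])]

print(counts)  # [3, 0]
```

Moving tf.add(x, y) from the earlier example into the true branch function in the same way would keep it from running when x > 1 is False.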
 


Copyright © 2015 • Hamed's Ensemble Blogging