| import tensorflow as tf | |
def masked_loss(label, pred):
    """Return the mean sparse cross-entropy over non-padding positions.

    Positions whose label is 0 are treated as padding: their per-position
    loss is zeroed and they are excluded from the denominator.

    Args:
        label: Integer class ids; 0 marks padding.
        pred: Unnormalized logits (``from_logits=True``) with the class
            dimension last. Presumably (batch, seq, vocab) -- confirm
            against the model output.

    Returns:
        Scalar tensor: sum of the unmasked per-position losses divided by
        the number of unmasked positions, or 0 when every position is
        padding.
    """
    mask = label != 0
    # reduction='none' keeps the per-position loss so the mask can be applied.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    loss = loss_object(label, pred)
    mask = tf.cast(mask, dtype=loss.dtype)
    loss *= mask
    # divide_no_nan: an all-padding batch would otherwise produce 0/0 = NaN,
    # which propagates into the reported loss and the gradients.
    return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))
def masked_accuracy(label, pred):
    """Return prediction accuracy over non-padding positions.

    Args:
        label: Integer class ids; 0 marks padding.
        pred: Class scores (logits or probabilities) with the class
            dimension last.

    Returns:
        float32 scalar: fraction of non-padding positions where the argmax
        of ``pred`` equals ``label``, or 0 when every position is padding.
    """
    # axis=-1 (previously hard-coded 2): identical for rank-3 input, and
    # also correct for any rank as long as classes are on the last axis.
    pred = tf.argmax(pred, axis=-1)
    label = tf.cast(label, pred.dtype)  # argmax returns int64 by default
    mask = label != 0
    match = (label == pred) & mask
    match = tf.cast(match, dtype=tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    # divide_no_nan avoids 0/0 = NaN when the batch is entirely padding.
    return tf.math.divide_no_nan(tf.reduce_sum(match), tf.reduce_sum(mask))