# conv2D.py (388 B)

  1. import tensorflow as tf
  2. from tensorflow.keras import layers
  3. from .. import load_op
  4. class Conv2D(layers.Layer):
  5. def __init__(self, kernel):
  6. super(Conv2D, self).__init__()
  7. self.kernel = kernel
  8. def call(self, inputs):
  9. ints = tf.dtypes.cast(inputs, dtype=tf.int32)
  10. outs = load_op.op_lib.MyConv2D(input=ints, filter=ints)
  11. return tf.dtypes.cast(outs, dtype=tf.float32)