import tensorflow as tf
from tensorflow.python.framework import tensor_shape
from tensorflow.keras import layers, initializers, regularizers, constraints

from .. import load_op

class Conv2D(layers.Layer):
  """2-D convolution layer backed by the custom ``MyConv2D`` op.

  The layer owns a single trainable kernel of shape
  ``(5, 5, input_channels, filters)`` and forwards its input together with
  that kernel to ``load_op.op_lib.MyConv2D``. The number of input channels
  is read from axis 3 of the input shape, so the input is expected to have
  at least four dimensions with channels on axis 3.

  Args:
    filters: int, dimensionality of the output space (last kernel axis).
    kernel_initializer: initializer for the kernel; anything accepted by
      ``keras.initializers.get``.
    kernel_regularizer: optional regularizer for the kernel; anything
      accepted by ``keras.regularizers.get``.
    kernel_constraint: optional constraint for the kernel; anything
      accepted by ``keras.constraints.get``.
    **kwargs: standard ``keras.layers.Layer`` keyword arguments such as
      ``name``, ``dtype`` or ``trainable``.
  """

  # Spatial extent of the kernel created in build(). Kept fixed because the
  # custom op may rely on it (NOTE(review): confirm against the op kernel).
  _KERNEL_SIZE = (5, 5)

  def __init__(self,
               filters=1,
               kernel_initializer='glorot_uniform',
               kernel_regularizer=None,
               kernel_constraint=None,
               **kwargs):
    # Forward the base-layer arguments so that name/dtype/trainable can be
    # set and so that from_config() can rebuild the layer from get_config().
    super(Conv2D, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)

  def build(self, input_shape):
    """Create the kernel weight once the input shape is known.

    Args:
      input_shape: shape of the layer input; anything convertible to
        ``tf.TensorShape``.

    Raises:
      ValueError: if the input has fewer than four dimensions or its
        channel dimension (axis 3) is not statically known.
    """
    input_shape = tf.TensorShape(input_shape)
    if input_shape.rank is not None and input_shape.rank < 4:
      raise ValueError(
          'Conv2D expects an input with at least 4 dimensions, got shape '
          '{}.'.format(input_shape))

    # dimension_value() normalizes both plain ints and TF1-style Dimension
    # objects to an int, or None when the dimension is unknown.
    input_channel = tensor_shape.dimension_value(input_shape[3])
    if input_channel is None:
      raise ValueError(
          'The channel dimension (axis 3) of the input must be defined, '
          'got shape {}.'.format(input_shape))
    self.input_channel = input_channel

    kernel_shape = self._KERNEL_SIZE + (self.input_channel, self.filters)
    self.kernel = self.add_weight(
        name='kernel',
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        trainable=True,
        dtype=self.dtype)

  def call(self, inputs):
    """Apply the custom ``MyConv2D`` op to ``inputs`` using the layer kernel."""
    return load_op.op_lib.MyConv2D(input=inputs, filter=self.kernel)

  def get_config(self):
    """Return the constructor arguments as a serializable dict."""
    config = super(Conv2D, self).get_config()
    config.update({
        'filters': self.filters,
        'kernel_initializer': initializers.serialize(self.kernel_initializer),
        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
        'kernel_constraint': constraints.serialize(self.kernel_constraint),
    })
    return config