train.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import tensorflow as tf
  2. import tensorflow.keras as keras
  3. from tensorflow.keras import layers
  4. from tensorflow.keras.layers import Input, Embedding, LSTM, Dense, Dropout, Flatten, MaxPooling2D, Conv2D
  5. from tensorflow.keras.models import Model, Sequential
  6. from tensorflow.keras.datasets import mnist
  7. from tensorflow.keras.utils import plot_model, to_categorical
  8. import numpy as np
  9. from IPython import embed
  10. my_matmul_module = tf.load_op_library('./matMul.so')
  11. batch_size = 128
  12. num_classes = 10
  13. epochs = 1 # 12
  14. # input image dimensions
  15. img_rows, img_cols = 28, 28
  16. # the data, split between train and test sets
  17. (x_train, y_train), (x_test, y_test) = mnist.load_data()
  18. x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
  19. x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
  20. input_shape = (img_rows, img_cols, 1)
  21. x_train = x_train.astype('float32')
  22. x_test = x_test.astype('float32')
  23. x_train /= 255
  24. x_test /= 255
  25. print('x_train shape:', x_train.shape)
  26. print(x_train.shape[0], 'train samples')
  27. print(x_test.shape[0], 'test samples')
  28. # convert class vectors to binary class matrices
  29. y_train = to_categorical(y_train, num_classes)
  30. y_test = to_categorical(y_test, num_classes)
  31. class Conv2DFPGA(layers.Layer):
  32. def __init__(self, kernel):
  33. super(Conv2DFPGA, self).__init__()
  34. self.kernel = kernel
  35. def call(self, inputs):
  36. ints = tf.dtypes.cast(inputs, dtype=tf.int32)
  37. outs = my_matmul_module.MyConv2D(input=ints, filter=ints)
  38. return tf.dtypes.cast(outs, dtype=tf.float32)
  39. class MyConv2D(layers.Conv2D):
  40. def __init__(self,
  41. filters,
  42. kernel_size,
  43. strides=(1, 1),
  44. padding='valid',
  45. data_format=None,
  46. dilation_rate=(1, 1),
  47. activation=None,
  48. use_bias=True,
  49. kernel_initializer='glorot_uniform',
  50. bias_initializer='zeros',
  51. kernel_regularizer=None,
  52. bias_regularizer=None,
  53. activity_regularizer=None,
  54. kernel_constraint=None,
  55. bias_constraint=None,
  56. **kwargs):
  57. super(MyConv2D, self).__init__(
  58. filters=filters,
  59. kernel_size=kernel_size,
  60. strides=strides,
  61. padding=padding,
  62. data_format=data_format,
  63. dilation_rate=dilation_rate,
  64. activation=activation,
  65. use_bias=use_bias,
  66. kernel_initializer=kernel_initializer,
  67. bias_initializer=bias_initializer,
  68. kernel_regularizer=kernel_regularizer,
  69. bias_regularizer=bias_regularizer,
  70. activity_regularizer=activity_regularizer,
  71. kernel_constraint=kernel_constraint,
  72. bias_constraint=bias_constraint,
  73. **kwargs)
  74. def call(self, inputs):
  75. #inputs.get_shape(),
  76. #filter_shape=self.kernel.shape,
  77. #dilation_rate=self.dilation_rate,
  78. #strides=self.strides,
  79. #padding=self._padding_op,
  80. #data_format=self._conv_op_data_format)
  81. #kernel.shape.ndims
  82. #inputs.get_shape().ndims
  83. if self.rank == 1 and inputs.get_shape(): #fpga restriction
  84. return my_matmul_module.MyConv2D(inputs, self.kernel)
  85. else:
  86. return super(MyConv2D, self).call(inputs)
  87. model = Sequential()
  88. model.add(MyConv2D(32, kernel_size=(3, 3),
  89. activation='relu',
  90. input_shape=input_shape))
  91. model.add(Conv2DFPGA([0,0]))
  92. model.add(Flatten())
  93. model.add(Dense(128, activation='relu'))
  94. model.add(Dropout(0.5))
  95. model.add(Dense(num_classes, activation='softmax'))
  96. model.compile(loss=keras.losses.categorical_crossentropy,
  97. optimizer=keras.optimizers.Adadelta(),
  98. metrics=['accuracy'])
  99. model.fit(x_train, y_train,
  100. batch_size=batch_size,
  101. epochs=epochs,
  102. verbose=1,
  103. validation_data=(x_test, y_test))
  104. score = model.evaluate(x_test, y_test, verbose=0)
  105. print('Test loss:', score[0])
  106. print('Test accuracy:', score[1])
  107. plot_model(model, to_file='model.png', expand_nested=True, show_shapes=True)