Ver código fonte

async c++ matrix

subDesTagesMitExtraKaese 5 anos atrás
pai
commit
f82d696d86
5 arquivos alterados com 71 adições e 39 exclusões
  1. BIN
      build/op_lib.so
  2. 13 6
      examples/train.py
  3. 2 16
      layers/conv2D.py
  4. 53 17
      src/conv2D.cpp
  5. 3 0
      src/conv2D.hpp

BIN
build/op_lib.so


+ 13 - 6
examples/train.py

@@ -40,13 +40,20 @@ y_train = to_categorical(y_train, num_classes)
 y_test = to_categorical(y_test, num_classes)
 
 a = layers.Input(dtype=tf.int32, shape=(28, 28, 1))
-b = Conv2DFPGA(32)(a)
-c = Conv2DFPGA(16)(a)
-d = Conv2DFPGA(2)(b)
-e = Conv2DFPGA(4)(c)
+b = Conv2DFPGA(2)(a)
+c = Conv2DFPGA(1)(a)
+d = Conv2DFPGA(1)(b)
+e = Conv2DFPGA(2)(c)
 
-x = layers.Add()([layers.Flatten()(d),layers.Flatten()(e)])
-z = layers.Dense(num_classes, activation='softmax')(x)
+print(a)
+print(b)
+print(c)
+print(d)
+print(e)
+
+x = layers.Add()([d,e])
+y = layers.Flatten()(x)
+z = layers.Dense(num_classes, activation='softmax')(y)
 
 model = Model(inputs=a, outputs=z)
 """

+ 2 - 16
layers/conv2D.py

@@ -36,19 +36,5 @@ class Conv2D(layers.Layer):
   def call(self, inputs):
 
     #out = tf.Tensor(tf.int32, shape=inputs.shape)
-
-    ch_inputs = tf.unstack(inputs, axis=3)#tf.dtypes.cast(inputs, dtype=tf.int32), axis=3)
-    ch_kernel = tf.unstack(tf.dtypes.cast(self.kernel, dtype=tf.int32), axis=2)
-
-    ch_outputs = [None] * len(ch_inputs)
-
-    for ch in range(len(ch_inputs)):
-      print(ch_inputs[ch], ch_kernel[ch])
-      ch_outputs[ch] = [None] * self.filters
-      kernel_2d = tf.unstack(ch_kernel[ch], axis=2)
-      for f in range(len(kernel_2d)):
-        ch_outputs[ch][f] = load_op.op_lib.MyConv2D(input=ch_inputs[ch], filter=kernel_2d[f], delay=(f+1)*100)
-      
-      ch_outputs[ch] = tf.stack(ch_outputs[ch], axis=2)
-    outs = tf.stack(ch_outputs, axis=2)
-    return outs #tf.dtypes.cast(outs, dtype=tf.float32)
+    intKernel = tf.cast(self.kernel, dtype=tf.int32)
+    return load_op.op_lib.MyConv2D(input=inputs, filter=intKernel, delay=1000*self.filters)

+ 53 - 17
src/conv2D.cpp

@@ -5,19 +5,39 @@
 
 volatile int instances = 0;
 volatile int inParallel = 0;
-std::mutex mu;
+std::mutex printMu;
 
-void delayThread(int ins, const char *name, int delay, std::function<void ()> done) {
-  mu.lock();
-  printf("parallel: %2d instance: %2d '%s' %dms sleep\n", ++inParallel, ins, name, delay);
-  mu.unlock();
+/**
+ * Simulated-work routine: logs entry, sleeps for `delay` milliseconds, logs
+ * exit, then invokes `done`.
+ *
+ * Within this function the global `inParallel` counter is only changed while
+ * `printMu` is held, so each increment/decrement and its printf form one
+ * critical section.
+ *
+ * @param done  completion callback, invoked once after the sleep and the
+ *              final log line.
+ */
+void Conv2DOp::delayThread(DoneCallback done) {
+  {
+    // Scoped guard: the mutex is released at the closing brace, i.e. before
+    // the sleep starts, exactly where the manual unlock() used to be.
+    std::lock_guard<std::mutex> guard(printMu);
+    printf("parallel: %2d instance: %2d '%s' %dms sleep\n", ++inParallel, instance, name().c_str(), delay);
+  }
   std::this_thread::sleep_for(milliseconds(delay));
-  mu.lock();
-  printf("parallel: %2d instance: %2d '%s' done\n", --inParallel, ins, name);
-  mu.unlock();
+  {
+    std::lock_guard<std::mutex> guard(printMu);
+    printf("parallel: %2d instance: %2d '%s' done\n", --inParallel, instance, name().c_str());
+  }
   done();
 }
 
+/**
+ * Work item for one (sample, channel, filter) combination; currently a stub.
+ *
+ * Creates rank-4 int32 views of the three tensors and takes `printMu`, but the
+ * debug dump inside the critical section is commented out, so no tensor
+ * element is read or written yet.
+ *
+ * @param input    input tensor, viewed as rank-4 int32.
+ * @param kernel   filter tensor, viewed as rank-4 int32.
+ * @param output   output tensor, viewed as rank-4 int32; not written here.
+ * @param sample   index along the batch dimension.
+ * @param channel  index along the input-channel dimension.
+ * @param filter   index of the filter.
+ */
+void Conv2DOp::fpgaCall(const Tensor *input, const Tensor *kernel, Tensor *output, int sample, int channel, int filter) {
+  // The views are not used yet; they are kept from the original.
+  // NOTE(review): presumably tensor<T, N>() validates dtype/rank - confirm
+  // before removing any of these calls.
+  auto input_tensor = input->tensor<int32, 4>();
+  auto kernel_tensor = kernel->tensor<int32, 4>();
+  auto output_tensor = output->tensor<int32, 4>();
+  int size = 24;  // edge length iterated by the disabled dump below
+
+  // RAII guard instead of manual lock()/unlock(): the mutex is still released
+  // if the dump below is re-enabled and leaves the function early or throws.
+  std::lock_guard<std::mutex> guard(printMu);
+  //printf(" sample: %3d, channel: %3d, filter: %3d\n", sample, channel, filter);
+  /*
+  for(int x=0; x<size; x++) {
+    for(int y=0; y<size; y++) {
+      printf("%c", input_tensor(sample, x, y, channel) > 0 ? '#' : ' ');
+    }
+    std::cout << std::endl;
+  }
+  std::cout << std::endl;
+  */
+}
+
 Conv2DOp::Conv2DOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
   instance = instances++;
   OP_REQUIRES_OK(context, context->GetAttr("delay", &delay));
@@ -29,22 +49,38 @@ void Conv2DOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
   // [ batch, in_rows, in_cols, in_depth ]
   const Tensor& input = context->input(0);
 
+  ///const int32 *p = input.flat<int32>().data();
+
   // Input filter is of the following dimensions:
   // [ filter_rows, filter_cols, in_depth, out_depth]
-  const Tensor& filter = context->input(1);
-  TensorShape filterShape = filter.shape();
+  const Tensor& kernel = context->input(1);
+
+  TensorShape kernel_shape = kernel.shape();
+  TensorShape input_shape = input.shape();
+  TensorShape output_shape = input.shape();
+  
 
-  TensorShape out_shape = input.shape();
+  int batchSize = input_shape.dim_size(0);
+  int channels = input_shape.dim_size(3);
+  int filters = kernel_shape.dim_size(3);
+
+  output_shape.set_dim(1, 24);
+  output_shape.set_dim(2, 24);
+  output_shape.set_dim(3, channels * filters);
 
   // Output tensor is of the following dimensions:
   // [ in_batch, out_rows, out_cols, out_depth ]
   Tensor* output = nullptr;
-  OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
-
-  context->cancellation_manager();
-
-  std::async(std::launch::async, delayThread, instance, name().c_str(), delay, done);
-  
+  OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &output));
+
+  for(int sample=0; sample<batchSize; sample++) {
+    for(int channel=0; channel<channels; channel++) {
+      for(int filter=0; filter<filters; filter++) {
+        std::async(std::launch::async, &Conv2DOp::fpgaCall, this, &input, &kernel, output, sample, channel, filter);
+      }
+    }
+  }
+  std::async(std::launch::async, &Conv2DOp::delayThread, this, done);
 }
 
 

+ 3 - 0
src/conv2D.hpp

@@ -25,5 +25,8 @@ class Conv2DOp : public AsyncOpKernel {
     int instance = -1;
     int delay = 1000;
 
+    void fpgaCall(const Tensor *input, const Tensor *kernel, Tensor *output, int sample, int channel, int filter);
+    void delayThread(DoneCallback done);
+
   //TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
 };