conv2D.cpp

#include "conv2D.hpp"

namespace tf_lib {

// `instances` hands out a unique id to every Conv2DOp that is constructed;
// `printMu` guards the (currently commented-out) debug output below.
volatile int instances = 0;
volatile int inParallel = 0;
std::mutex printMu;

Conv2DOp::Conv2DOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
  instance = instances++;
  OP_REQUIRES_OK(context, context->GetAttr("delay", &delay));
}

void Conv2DOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // Input tensor is of the following dimensions:
  // [ batch, in_rows, in_cols, in_depth ]
  const Tensor& input = context->input(0);
  ///const int32 *p = input.flat<int32>().data();

  // Input filter is of the following dimensions:
  // [ filter_rows, filter_cols, in_depth, out_depth ]
  const Tensor& kernel = context->input(1);

  TensorShape kernel_shape = kernel.shape();
  TensorShape input_shape = input.shape();

  int batchSize = input_shape.dim_size(0);
  int channels = input_shape.dim_size(3);
  int filters = kernel_shape.dim_size(3);

  TensorShape output_shape;
  const int32 dims[] = {batchSize, outputSize, outputSize, channels * filters};
  // Use the *_ASYNC variants so `done` is still invoked if shape construction
  // or allocation fails.
  OP_REQUIRES_OK_ASYNC(
      context, TensorShapeUtils::MakeShape(dims, 4, &output_shape), done);

  //printMu.lock();
  //std::cout << output_shape.DebugString() << std::endl;
  //printMu.unlock();

  // Output tensor is of the following dimensions:
  // [ in_batch, out_rows, out_cols, out_depth ]
  Tensor* output = nullptr;
  OP_REQUIRES_OK_ASYNC(
      context, context->allocate_output(0, output_shape, &output), done);

  auto input_tensor = input.tensor<int32, 4>();
  auto output_tensor = output->tensor<int32, 4>();

  // One job per (sample, channel, filter) combination; each job's payload is
  // the outputSize x outputSize patch of that sample/channel of the input.
  std::shared_ptr<JobList> jobs(
      new JobList(Module::dummyModule, batchSize * channels * filters));

  for (int sample = 0; sample < batchSize; sample++) {
    for (int channel = 0; channel < channels; channel++) {
      for (int filter = 0; filter < filters; filter++) {
        std::shared_ptr<Job>& job = jobs->getJob(
            sample * channels * filters + channel * filters + filter);
        for (int x = 0; x < outputSize; x++) {
          for (int y = 0; y < outputSize; y++) {
            job->setPayload(x * outputSize + y, input_tensor(sample, x, y, channel));
          }
        }
      }
    }
  }

  // Capture `jobs` by value so the JobList stays alive until the asynchronous
  // response arrives; for now only the first response word is written back.
  jobs->setDoneCallback([output_tensor, jobs, done] {
    output_tensor(0) = jobs->getJob(0)->getResponsePayload(0);
    done();
  });

  connectionManager.sendJobListAsync(jobs);
}
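
// A kernel such as Conv2DOp is normally exposed to TensorFlow through an op
// definition plus a kernel registration. Neither appears in this file, so the
// following is only a sketch; the op name "MyConv2D" and its exact signature
// are assumptions:
//
//   REGISTER_OP("MyConv2D")
//       .Attr("delay: int")
//       .Input("input: int32")
//       .Input("kernel: int32")
//       .Output("output: int32");
//
//   REGISTER_KERNEL_BUILDER(Name("MyConv2D").Device(DEVICE_CPU), Conv2DOp);
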
static Status MatMulGradHelper(FunctionDef* g, const string& opname,
                               const string& attr_adj_x,
                               const string& attr_adj_y, const string& x0,
                               bool ax0, const string& x1, bool ax1,
                               const string& y0, bool ay0, const string& y1,
                               bool ay1) {
  // The final outputs are "dx" and "dy".
  std::vector<FDH::Node> nodes = {
      {{"dx"},
       opname,
       {x0, x1},
       {{"T", "$T"}, {attr_adj_x, ax0}, {attr_adj_y, ax1}}},
      {{"dy"},
       opname,
       {y0, y1},
       {{"T", "$T"}, {attr_adj_x, ay0}, {attr_adj_y, ay1}}},
  };
  *g = FDH::Define(
      // Arg defs
      {"x: T", "y: T", "dz: T"},
      // Ret val defs
      {"dx: T", "dy: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      nodes);
  return Status::OK();
}

Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
  const string opname = "MyMatMul";
  const string attr_adj_x = "transpose_a";
  const string attr_adj_y = "transpose_b";
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex is not supported yet.");
  }
  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_y, &tb));

  // For z = x * y the gradients are dx = dz * y^T and dy = x^T * dz; the four
  // cases below pick the operands and transpose flags accordingly.
  if (!ta && !tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
                            true, "x", true, "dz", false);
  }
  if (!ta && tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
                            false, "dz", true, "x", false);
  }
  if (ta && !tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", false, "dz",
                            true, "x", false, "dz", false);
  }
  CHECK(ta && tb);
  return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", true, "dz",
                          true, "dz", true, "x", true);
}
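
// MatMulGrad has the shape TensorFlow expects for a function-based gradient
// and would typically be attached to its op with REGISTER_OP_GRADIENT. That
// registration is not part of this file; the line below is only a sketch:
//
//   REGISTER_OP_GRADIENT("MyMatMul", MatMulGrad);
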
}  // namespace tf_lib