zero_out.cc 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  #include <atomic>
  #include <cstdint>
  #include <cstdio>

  #include "tensorflow/core/framework/op.h"
  #include "tensorflow/core/framework/shape_inference.h"

  using namespace tensorflow;
  4. REGISTER_OP("ZeroOut")
  5. .Input("to_zero: int32")
  6. .Output("zeroed: int32")
  7. .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
  8. c->set_output(0, c->input(0));
  9. return Status::OK();
  10. });
  11. #include "tensorflow/core/framework/op_kernel.h"
  12. using namespace tensorflow;
  13. class ZeroOutOp : public OpKernel {
  14. public:
  15. explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
  16. void Compute(OpKernelContext* context) override {
  17. // Grab the input tensor
  18. const Tensor& input_tensor = context->input(0);
  19. auto input = input_tensor.flat<int32>();
  20. printf("call n: %d\n", n++);
  21. // Create an output tensor
  22. Tensor* output_tensor = NULL;
  23. OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
  24. &output_tensor));
  25. auto output_flat = output_tensor->flat<int32>();
  26. // Set all but the first element of the output tensor to 0.
  27. const int N = input.size();
  28. for (int i = 1; i < N; i++) {
  29. output_flat(i) = 0;
  30. }
  31. // Preserve the first input value if possible.
  32. if (N > 0) output_flat(0) = input(0);
  33. }
  34. int n = 0;
  35. };
  36. REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);