Registering TensorFlow ops and op kernels


Contents:

Generating op definitions

Generating op kernel factories


Generating op definitions

When the TensorFlow library is loaded, the static variables inside the shared object (.so) are instantiated.

Take the AddN op as an example. In math_ops.cc, REGISTER_OP("AddN") is actually a static variable definition:

REGISTER_OP("AddN")
    .Input("inputs: N * T")
    .Output("sum: T")
    .Attr("N: int >= 1")
    .Attr("T: {numbertype, variant}")
    .SetIsCommutative()
    .SetIsAggregate()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle cur = c->input(c->num_inputs() - 1);
      // ...
    });

Expanding REGISTER_OP("AddN") gives:

static ::tensorflow::register_op::OpDefBuilderReceiver register_op165 \
    __attribute__((unused)) = \
    ::tensorflow::register_op::OpDefBuilderWrapper("AddN")

The constructor of OpDefBuilderReceiver is in op.cc:

namespace register_op {
OpDefBuilderReceiver::OpDefBuilderReceiver(
    const OpDefBuilderWrapper& wrapper) {
  OpRegistry::Global()->Register(
      [wrapper](OpRegistrationData* op_reg_data) -> Status {
        return wrapper.builder().Finalize(op_reg_data);
      });
}
}  // namespace register_op

OpRegistry::Register in op.cc calls OpRegistry::RegisterAlreadyLocked, which runs op_data_factory to produce the op's definition, an OpRegistrationData.

OpRegistry* OpRegistry::Global() {
  static OpRegistry* global_op_registry = new OpRegistry;
  return global_op_registry;
}

void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
  mutex_lock lock(mu_);
  if (initialized_) {
    TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
  } else {
    deferred_.push_back(op_data_factory);
  }
}

mutable std::unordered_map<string, const OpRegistrationData*> registry_;

Status OpRegistry::RegisterAlreadyLocked(
    const OpRegistrationDataFactory& op_data_factory) const {
  std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
  Status s = op_data_factory(op_reg_data.get());
  if (s.ok()) {
    s = ValidateOpDef(op_reg_data->op_def);
    if (s.ok() &&
        !gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
                                 op_reg_data.get())) {
      s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
    }
  }
  Status watcher_status = s;
  if (watcher_) {
    watcher_status = watcher_(s, op_reg_data->op_def);
  }
  if (s.ok()) {
    op_reg_data.release();
  } else {
    op_reg_data.reset();
  }
  return watcher_status;
}

OpDefBuilder::Finalize in op_def_builder.cc then populates the OpRegistrationData.

Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
  std::vector<string> errors = errors_;
  *op_reg_data = op_reg_data_;
  OpDef* op_def = &op_reg_data->op_def;
  for (StringPiece attr : attrs_) {
    FinalizeAttr(attr, op_def, &errors);
  }
  for (StringPiece input : inputs_) {
    FinalizeInputOrOutput(input, false, op_def, &errors);
  }
  for (StringPiece output : outputs_) {
    FinalizeInputOrOutput(output, true, op_def, &errors);
  }
  for (StringPiece control_output : control_outputs_) {
    FinalizeControlOutput(control_output, op_def, &errors);
  }
  FinalizeDoc(doc_, op_def, &errors);
  if (errors.empty()) return Status::OK();
  return errors::InvalidArgument(absl::StrJoin(errors, "\n"));
}

Later, OpRegistry::LookUp in op.cc can retrieve the definition registered for a given op type:

Status OpRegistry::LookUp(const string& op_type_name,
                          const OpRegistrationData** op_reg_data) const {
  {
    tf_shared_lock l(mu_);
    if (initialized_) {
      if (const OpRegistrationData* res =
              gtl::FindWithDefault(registry_, op_type_name, nullptr)) {
        *op_reg_data = res;
        return Status::OK();
      }
    }
  }
  return LookUpSlow(op_type_name, op_reg_data);
}

Generating op kernel factories

This, too, happens when the TensorFlow library is loaded.

Staying with the AddN op: in aggregate_ops.cc, the code related to the AddNOp kernel looks like this:

template <typename Device, typename T>
class AddNOp : public OpKernel {
 public:
  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* ctx) override {
    if (!ctx->ValidateInputsAreSameShape(this)) return;
    // ...
    ctx->set_output(0, out);
  }
};

#define REGISTER_ADDN(type, dev)                                   \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AddN").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
      AddNOp<dev##Device, type>)

#define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU)

TF_CALL_NUMBER_TYPES(REGISTER_ADDN_CPU);
REGISTER_ADDN_CPU(Variant);
#undef REGISTER_ADDN_CPU

Expanding the TF_CALL_NUMBER_TYPES macro yields the following; it creates one OpKernelRegistrar object for each supported data type:

constexpr bool should_register_389__flag = true;
static ::tensorflow::kernel_factory::OpKernelRegistrar
    registrar__body__389__object(
        should_register_389__flag
            ? ::tensorflow::register_kernel::Name("AddN")
                  .Device(DEVICE_CPU)
                  .TypeConstraint<::tensorflow::int64>("T")
                  .Build()
            : nullptr,
        "AddNOp",
        [](::tensorflow::OpKernelConstruction* context)
            -> ::tensorflow::OpKernel* {
          return new AddNOp<CPUDevice, ::tensorflow::int64>(context);
        });
// ... (one such registrar per data type)

The constructor of OpKernelRegistrar used above is defined in op_kernel.h. The expression ::tensorflow::register_kernel::Name("AddN").Device(DEVICE_CPU).TypeConstraint<::tensorflow::int64>("T").Build() produces the KernelDef argument, and the lambda [](::tensorflow::OpKernelConstruction* context) -> ::tensorflow::OpKernel* { return new AddNOp<CPUDevice, ::tensorflow::int64>(context); } is the factory function that creates kernels of type AddNOp, matching the create_fn parameter.

class OpKernelRegistrar {
 public:
  // ...
  // Registers the given factory function with TensorFlow. This is equivalent
  // to registering a factory whose Create function invokes `create_fn`.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*)) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name,
                   absl::make_unique<PtrOpKernelFactory>(create_fn));
    }
  }

OpKernelRegistrar::InitInternal in op_kernel.cc inserts the newly created KernelRegistration object into the global global_registry->registry map.

Later, when the computation graph needs this kernel, the KernelRegistration object is looked up by its kernel key, and its factory is called to produce an AddNOp instance, which is then assigned to the graph node.

struct KernelRegistry {
  mutex mu;
  std::unordered_multimap<string, KernelRegistration> registry GUARDED_BY(mu);
};

void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                     StringPiece kernel_class_name,
                                     std::unique_ptr<OpKernelFactory> factory) {
  // See comments in register_kernel::Name in header for info on _no_register.
  if (kernel_def->op() != "_no_register") {
    const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());
    // To avoid calling LoadDynamicKernels DO NOT CALL GlobalKernelRegistryTyped
    // here.
    // InitInternal gets called by static initializers, so it ends up executing
    // before main. This causes LoadKernelLibraries function to get called
    // before some file libraries can initialize, which in turn crashes the
    // program flakily. Until we get rid of static initializers in kernel
    // registration mechanism, we have this workaround here.
    auto global_registry =
        reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
    mutex_lock l(global_registry->mu);
    global_registry->registry.emplace(
        key,
        KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
  }
  delete kernel_def;
}