TensorFlow op和op kernel的註冊
本文連結地址:
「連結」
Content:
生成ops的定義
生成op kernels的factory定義
回目錄
生成ops的定義
TensorFlow Library在載入的時候,其中so裡面的靜態變數會例項化。
以AddN這個op為例,在檔案math_ops。cc裡面的REGISTER_OP(“AddN”)其實是個靜態變數定義:
REGISTER_OP(“AddN”) 。Input(“inputs: N * T”) 。Output(“sum: T”) 。Attr(“N: int >= 1”) 。Attr(“T: {numbertype, variant}”) 。SetIsCommutative() 。SetIsAggregate() 。SetShapeFn([](InferenceContext* c) { ShapeHandle cur = c->input(c->num_inputs() - 1); ………… });
REGISTER_OP(“AddN”)展開後為:
static ::tensorflow::register_op::OpDefBuilderReceiver register_op165 \ __attribute__((unused)) = \ ::tensorflow::register_op::OpDefBuilderWrapper
在檔案op。cc中有OpDefBuilderReceiver的構造方法:
namespace register_op {OpDefBuilderReceiver::OpDefBuilderReceiver( const OpDefBuilderWrapper
透過檔案op。cc中的OpRegistry::Register呼叫OpRegistry::RegisterAlreadyLocked來使op_data_factory產生op的定義OpRegistrationData。
static OpRegistry* OpRegistry::Global() { static OpRegistry* global_op_registry = new OpRegistry; return global_op_registry;} void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) { mutex_lock lock(mu_); if (initialized_) { TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory)); } else { deferred_。push_back(op_data_factory); }} mutable std::unordered_map
透過檔案op_def_builder。cc中的OpDefBuilder::Finalize來給OpRegistrationData賦值。
Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const { std::vector
以後使用的時候就可以透過檔案op。cc中的OpRegistry::LookUp獲取到op type對應的定義。
Status OpRegistry::LookUp(const string& op_type_name, const OpRegistrationData** op_reg_data) const { { tf_shared_lock l(mu_); if (initialized_) { if (const OpRegistrationData* res = gtl::FindWithDefault(registry_, op_type_name, nullptr)) { *op_reg_data = res; return Status::OK(); } } } return LookUpSlow(op_type_name, op_reg_data);}
生成op kernels的factory定義
同樣,也是在TensorFlow Library載入的時候。
還是以AddN這個op為例,在檔案aggregate_ops。cc裡面,關於AddNOp這個op_kernel相關程式碼如下:
template
TF_CALL_NUMBER_TYPES宏展開後如下,它會對所以的資料型別都分別建立一個OpKernelRegistrar物件:
constexpr bool should_register_389__flag = \ true; \ static ::tensorflow::kernel_factory::OpKernelRegistrar \ registrar__body__389__object( \ should_register_389__flag \ ? ::tensorflow::register_kernel::Name(“AddN”)。Device(DEVICE_CPU)。TypeConstraint<::tensorflow::int64>(“T”)。Build() \ : nullptr, \ “AddNOp
op_kernel。h中有上面OpKernelRegistrar的建構函式定義:
::tensorflow::register_kernel::Name(“AddN”)。Device(DEVICE_CPU)。TypeConstraint<::tensorflow::int64>(“T”)。Build()對應KernelDef型別,
lambda表示式[](::tensorflow::OpKernelConstruction* context) -> ::tensorflow::OpKernel* { return new AddNOp(context); }為生成AddNOp這個型別kernel的factory函式,對應引數create_fn。
class OpKernelRegistrar { public:………… // Registers the given factory function with TensorFlow。 This is equivalent // to registering a factory whose Create function invokes `create_fn`。 OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name, OpKernel* (*create_fn)(OpKernelConstruction*)) { // Perform the check in the header to allow compile-time optimization // to a no-op, allowing the linker to remove the kernel symbols。 if (kernel_def != nullptr) { InitInternal(kernel_def, kernel_class_name, absl::make_unique
檔案op_kernel。cc中的OpKernelRegistrar::InitInternal完成把新建立的KernelRegistration物件插入到全域性的global_registry->registry map中去。
後面如果計算圖中如果需要它,會透過kernel的key來獲取物件KernelRegistration,然後呼叫它的factory來產生AddNOp例項,然後賦值給節點node。
struct KernelRegistry { mutex mu; std::unordered_multimap