2019 06 21 tensorflow eager
---
layout: post
title: "tensorflow_eager"
subtitle: "tensorflow_eager"
date: 2019-06-21 14:22:49
author: "none"
header-img: "img/posts/default_post.jpg"
catalog: true
tags:
  - tag
---
pywrap_tensorflow
# tensorflow\__init__.py
from tensorflow.python import pywrap_tensorflow
# tensorflow\python\__init__.py
from tensorflow.python import pywrap_tensorflow
# tensorflow\python\pywrap_tensorflow.py
from tensorflow.python.pywrap_tensorflow_internal import *
SWIG 根据 `*.i` 接口文件生成两个文件:
1. `pywrap_tensorflow_internal.cc`:包装 C API,编译后生成动态链接库(`.so`);
2. `pywrap_tensorflow_internal.py`:供 Python 编程使用的接口文件,负责加载上述动态链接库。
例如,使用 SWIG 将 C++ 源文件中的 `TFE_Py_Execute` 函数暴露给 Python 使用:
// tensorflow\python\tensorflow.i
%include "tensorflow/python/pywrap_tfe.i"
// tensorflow\python\pywrap_tfe.i
%{
#include "tensorflow/python/eager/pywrap_tfe.h"
%}
// (此处省略相应的 SWIG typemap 语法,用于在 Python 与 C 类型之间进行映射)
`TFE_Py_Execute` 函数的声明与实现如下:
// tensorflow\python\eager\pywrap_tfe.h
// Executes the eager op `op_name` on `device_name` with the given input
// handles and Python-side attributes, writing the produced handles into
// `outputs`. Errors are reported through `out_status`.
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
const char* op_name, TFE_InputTensorHandles* inputs,
PyObject* attrs, TFE_OutputTensorHandles* outputs,
TF_Status* out_status);
// tensorflow\python\eager\pywrap_tfe_src.cc
// Creates the eager op `op_name`, places it on `device_name`, adds `inputs`
// and the attributes in `attrs`, then executes it.
//
// Parameters:
//   ctx         - eager context used to create the op.
//   device_name - device to place the op on.
//   op_name     - primitive op or function name (resolved by TFE_NewOp).
//   inputs      - input tensor handles, added in order.
//   attrs       - Python-side attribute values, applied via SetOpAttrs.
//   outputs     - on entry its size is passed to TFE_Execute as the output
//                 count; afterwards it is resized to the count TFE_Execute
//                 reports.
//   out_status  - receives any error; on failure " [Op:<op_name>]" is
//                 appended to the message.
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
                    const char* op_name, TFE_InputTensorHandles* inputs,
                    PyObject* attrs, TFE_OutputTensorHandles* outputs,
                    TF_Status* out_status) {
  TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
  // TFE_NewOp returns nullptr on failure, so there is nothing to delete.
  if (TF_GetCode(out_status) != TF_OK) return;
  TFE_OpSetDevice(op, device_name, out_status);
  if (TF_GetCode(out_status) == TF_OK) {
    // Index with the container's own size type so there is no
    // signed/unsigned comparison (-Wsign-compare) when size() is unsigned.
    for (decltype(inputs->size()) i = 0;
         i < inputs->size() && TF_GetCode(out_status) == TF_OK; ++i) {
      TFE_OpAddInput(op, inputs->at(i), out_status);
    }
  }
  if (TF_GetCode(out_status) == TF_OK) {
    // Called before the GIL is released below, presumably because it reads
    // the PyObject* `attrs`.
    SetOpAttrs(ctx, op, attrs, 0, out_status);
  }
  // Release the GIL for execution; this region makes no Python C-API calls.
  Py_BEGIN_ALLOW_THREADS;
  if (TF_GetCode(out_status) == TF_OK) {
    // TFE_Execute takes the count through an int*; make the narrowing from
    // the container's size explicit.
    int num_outputs = static_cast<int>(outputs->size());
    TFE_Execute(op, outputs->data(), &num_outputs, out_status);
    outputs->resize(num_outputs);
  }
  if (TF_GetCode(out_status) != TF_OK) {
    // Tag the message with the op name so the error identifies the op.
    TF_SetStatus(out_status, TF_GetCode(out_status),
                 tensorflow::strings::StrCat(TF_Message(out_status),
                                             " [Op:", op_name, "]")
                     .c_str());
  }
  TFE_DeleteOp(op);
  Py_END_ALLOW_THREADS;
}
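在 Python 端,可以使用如下代码调用经 SWIG 包装后的 `TFE_Py_Execute` 函数。注意 Python 端的参数与 C 函数签名并不一一对应(例如传入的是 `num_outputs`,输出张量通过返回值获得),这正是 `.i` 文件中类型映射(typemap)的作用: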
# tensorflow\python\eager\execute.py
# Python-side call of the SWIG-wrapped TFE_Py_Execute. The output handles come
# back as the return value; `num_outputs` presumably becomes the initial size of
# the C++ `outputs` container (read via outputs->size()) — TODO confirm against
# the typemaps in pywrap_tfe.i.
tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
op_name, inputs, attrs,
num_outputs)
- `TFE_NewOp`:C API 中的函数(`c_api.cc`)
- `TFE_Op`:C API 内部定义的结构体(`c_api_internal.h`)
- `EagerOperation`:C++ 类,被 `TFE_Op` 包装
// tensorflow\c\eager\c_api.cc
// Creates a TFE_Op for `op_or_function_name`, which may name either a
// primitive operation or a function registered with `ctx`.
//
// Returns a heap-allocated TFE_Op (TFE_Py_Execute later releases it with
// TFE_DeleteOp), or nullptr with `status` set when the name cannot be
// resolved.
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  const char* const op_name = op_or_function_name;
  const tensorflow::AttrTypeMap* attr_types;
  bool names_function = false;
  status->status =
      tensorflow::AttrTypeMapForOp(op_name, &attr_types, &names_function);
  if (!status->status.ok()) return nullptr;

  if (names_function) {
    // Functions get no TFE_OpInferenceContext; only their registration in
    // the context is checked.
    if (ctx->context->FindFunctionByName(op_name)) {
      return new TFE_Op(ctx, op_name, true, attr_types, nullptr);
    }
    status->status = tensorflow::errors::NotFound(
        "'", op_name,
        "' is neither a type of a primitive operation nor a name "
        "of a function registered in binary running on ",
        tensorflow::port::Hostname(),
        ". Make sure the operation or function is "
        "registered in the binary running in this process.");
    return nullptr;
  }

  // Primitive op: look up its OpDef and build an inference context from it.
  const tensorflow::OpDef* op_def;
  status->status = tensorflow::OpDefForOp(op_name, &op_def);
  if (!status->status.ok()) return nullptr;
  return new TFE_Op(ctx, op_name, false, attr_types,
                    new TFE_OpInferenceContext(op_def));
}
// tensorflow\c\eager\c_api_internal.h
struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
TFE_OpInferenceContext* inference_ctx)
: operation(ctx->context, op, is_function, t),
inference_ctx(inference_ctx) {}
tensorflow::EagerOperation operation;
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
};
EagerOperation
tensorflow\core\common_runtime\eager\eager_operation.h
EagerTensor
`EagerTensor` 是一个 Python 类:它由 C API 函数 `TFE_Py_InitEagerTensor` 以 `_EagerTensorBase` 为基类动态创建,并注册到当前模块中:
# tensorflow\python\framework\ops.py
from tensorflow.python import pywrap_tensorflow as c_api
# TODO(agarwal): consider getting rid of this.
# Python base class for eager tensors (only its docstring is shown here).
class _EagerTensorBase(Tensor):
"""Base class for EagerTensor."""
# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
# registers it with the current module.
# The concrete EagerTensor type is therefore produced by the C extension
# (c_api is pywrap_tensorflow), not written in Python.
EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)
eager context
tensorflow\python\eager\context.py
def _create_context():
  """Create the module-level `_context` singleton if it does not exist yet.

  Creation happens while holding `_context_lock`, and the check is repeated
  under the lock so that only one Context is ever built.
  """
  global _context
  with _context_lock:
    if _context is not None:
      # Another caller already created it.
      return
    _context = Context()
def context():
  """Return the singleton Context, creating it on first use."""
  # Fast path: skip the lock once the singleton exists.
  if _context is not None:
    return _context
  _create_context()
  return _context
enable_eager_execution
tensorflow\python\framework\ops.py
@tf_export(v1=["enable_eager_execution"])
def enable_eager_execution(config=None, device_policy=None,
                           execution_mode=None):
  """Request eager execution (TF1 API).

  Marks `_api_usage_gauge`, then delegates to
  `enable_eager_execution_internal` unless `context.default_execution_mode`
  is already `context.EAGER_MODE`, in which case nothing else happens and
  None is returned.

  Args:
    config: Forwarded unchanged to `enable_eager_execution_internal`.
    device_policy: Forwarded unchanged.
    execution_mode: Forwarded unchanged.

  Returns:
    The result of `enable_eager_execution_internal`, or None when the default
    mode is already eager.
  """
  # Presumably usage telemetry for this API — TODO confirm.
  _api_usage_gauge.get_cell().set(True)
  already_eager = context.default_execution_mode == context.EAGER_MODE
  if already_eager:
    return None
  return enable_eager_execution_internal(
      config=config,
      device_policy=device_policy,
      execution_mode=execution_mode,
      server_def=None)