Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

C-API for inference. #1062

Merged
merged 57 commits into from
Apr 21, 2017
Merged

C-API for inference. #1062

merged 57 commits into from
Apr 21, 2017

Conversation

reyoung
Copy link
Collaborator

@reyoung reyoung commented Jan 4, 2017

@@ -0,0 +1,49 @@
#include "PaddleCAPI.h"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

看到大家都把这个目录叫做 "c", 这个作为参考

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/dmlc/mxnet/tree/master/src/c_api

一般项目目录第一层是编码语言,就比如Paddle

-- paddle
 |- python
 |- proto

所以有一些项目的会把这部分放到c目录下。而对于Paddle,一来C/CPP区分界限不明显,二来但就这个C-API来说,本身也是C与C++的混合写法。所以很难独立成C目录。

@@ -0,0 +1,54 @@
#ifndef __PADDLE_PADDLE_CAPI_PADDLECAPI_H_INCLUDED__
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#pragma once 我们的规范貌似是这个?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个因为是会被用户使用的库,所以 还是用 #if 作为guard把。

@@ -0,0 +1,27 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后边的test和前面的关系不大是么

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

抱歉,这个文件是勿提交的

extern "C" {
#endif

#define PD_NO_ERROR 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是不是定义一个enum更好

typedef enum {
  PD_NO_ERROR = 0,
  PD_NULLPTR = 1
} PD_Code;

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是。

@reyoung
Copy link
Collaborator Author

reyoung commented Jan 11, 2017

@wangkuiyi @hedaoyuan @jacquesqiao This C-API is only used for model inference now. And it is not part of current API design. However, neural network inference API is an urgent issue now.

It just exposes some RAW Paddle APIs to C.

Also change previous design file, make concept consistant.
@reyoung reyoung requested a review from gangliao March 26, 2017 09:10
@reyoung
Copy link
Collaborator Author

reyoung commented Mar 26, 2017

麻烦 @wangkuiyi @jacquesqiao @hedaoyuan @gangliao @QiJune 来Review一下这个PR吧。设计文档在 这里

@reyoung reyoung requested review from gangliao and removed request for gangliao March 27, 2017 03:06
*width = m->width();
*height = m->height();
}
```

其中`paddle/math/matrix.hpp`文件内容为:
其中`paddle/capi/CMatrix.hpp`文件内容为:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

文件名采用哪种命名呢,代码里面是matrix.hMatrix.cpp,也需要统一一下吧

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个命名确实是统一的。。

对于C的header,为全小写加下划线形式。
对于CPP的source,为大写间隔。

只是这两个文件面向的语言不一样,因而采用更适合那个语言的命名风格。C语言的函数命名同理。


通常,这个结构体包含两个项目。

* `type`是一个类型的标志。对于每种类型,type字段均不尽相同。这样,即使C-API接受的类型全是`void *`,我们也可以确定每一个参数的类型。
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type在目前的代码里面好像没有起到什么作用,建议在cast之后检查一下type

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前确实没用,但是每次cast都检查也有点难受。毕竟检查都是要耗时的。。

这个type预留的目的还是为了有一定的多态性。。。

譬如,可以写一个 paddle_destroy(void*),这样可以删除任意类型的paddle对象。或者,paddle_tensor_get_data可以接受matrix和vector都行。。

@@ -70,20 +70,20 @@ extern "C"
paddle_error paddle_matrix_shape(paddle_matrix matrix,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

函数名是不是统一使用paddle__动词词组比较好,代码里面是paddle_matrix_get_shape,arguments里面有几个函数不是。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.


### libpaddle\_capi_shared.{so, dylib}

`libpaddle_capi_shared`是C-API导出的动态库。这个动态库的连接参数与Paddle的其他二进制(例如`paddle_traienr`)类似。用户可以直接使用这个动态库来引入Paddle C-API。具体使用方法为`-lpaddle_capi_shared`。
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: paddle_traienr -> paddle_trainer


using paddle::capi::cast;

#define castArg(v) cast<paddle::capi::CArguments>(v)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上文提到的,castArg时,是否需要检查type是不是kARGUMENTS

printf("%.2f ", array[i]);
}
printf("\n");

Copy link
Contributor

@Xreki Xreki Apr 12, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有调用paddle_matrix_destroy等接口释放资源。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确实,原来想程序退出了资源总会被OS回收,但是在example里面这么写确实不好。

NeuralNetwork* network) {
return new MyNeuralNetwork(name, network);
}
} // namespace paddle
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MyNeuralNetwork定义在其他地方没有用到。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NerualNetwork本身不能直接new的。他的构造函数是私有的。。这里只是为了把它的构造函数expose出来。

}

paddle_error paddle_gradient_machine_create_shared_param(
paddle_gradient_machine origin,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CGradientMachine是否可以添加configsize这两个成员,这样这个接口可以简化成

paddle_error paddle_gradient_machine_create_shared_param(paddle_gradient_machine origin, 
                                                         paddle_gradient_machine* slave)

也能保证origin肯定用到的是正确的config

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

share参数并不一定非要配置一样。有可能是两个不同的网络结构,但是参数有共享。

paddle_arguments args = paddle_arguments_create_none();
ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_resize(args, 1));

paddle_matrix mat = paddle_matrix_create(128, 64, false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有GPU的单测?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确实没有,这就加一下。

@@ -0,0 +1,46 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上面那个text_Init.cpp是空的,是误加还是尚未实现?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

误加的

@reyoung reyoung changed the title C-API for predict. C-API for inference. Apr 13, 2017
@reyoung reyoung merged commit f979b12 into PaddlePaddle:develop Apr 21, 2017
@reyoung reyoung deleted the feature/c_api branch April 25, 2017 12:34
zhhsplendid pushed a commit to zhhsplendid/Paddle that referenced this pull request Sep 25, 2019
wangxicoding pushed a commit to wangxicoding/Paddle that referenced this pull request Dec 9, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants