Keras vs PyTorch

Keras vs PyTorch

下面做一个简单的比较,顺便跟Tensorflow比较一下,

  Keras PyTorch TensorFlow
API High Low High and Low
架构(学习难度) Simple, concise, readable Complex, less readable Not easy to use
调试 Simple network, so debugging is not often needed Good debugging capabilities Difficult to conduct debugging
性能 Slow, low performance Fast, high-performance Fast, high-performance
语言 Python C,Python,底层是依赖Torch C++, CUDA, Python
后台 TensorFlow, Theano and Microsoft CNTK backend 没有 没有

Keras相对简单一点,下面直观的感觉两者代码的差别。2段代码创建2个卷积层,relu激活,2个pooling层,最后softmax。

Keras

model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(MaxPool2D())
model.add(Conv2D(16, (3, 3), activation='relu'))
model.add(MaxPool2D())
model.add(Flatten())
model.add(Dense(10, activation='softmax'))

PyTorch

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.conv2 = nn.Conv2d(32, 16, 3)
        self.fc1 = nn.Linear(16 * 6 * 6, 10)
        self.pool = nn.MaxPool2d(2, 2)
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 6 * 6)
        x = F.log_softmax(self.fc1(x), dim=-1)
        return x
model = Net()

下面给出所有的学习框架的流行度,

这个图是2018的统计结果。

下面比较一下输出模型。

PyTorch输出格式为Pickles, 它是基于python的,不是portable的,而 Keras使用是JSON + H5是portable的。
如果Keras底层是Tensorflow, Keras可以通过TensorFlow for Mobile和TensorFlow Lite将训练出来的模型很容易的到移动端. 或者通过 TensorFlow.js,keras.js部署到web端.
输出PyTorch 相对困难, 目前可行的方法是PyTorch输出model然后使用ONNX,将模型转换成Caffe2。

看一下性能的benchmark

    分享到:

留言

你的邮箱是保密的 必填的信息用*表示