Failed to import pydot. You must `pip install pydot` and install graphviz
我们在调用keras里面的高级API——plot_model(),去画神经网络的结构图的时候可能会遇到两个报错问题。
第一个是说keras.utils里面不存在plot_model()这个用法。
cannot import name 'plot_model' from 'keras.utils'
这个问题好解决,因为keras里面确实没有plot_model()用法,但是他的好兄弟——TensorFlow里面有.....
直接这样导入:
from tensorflow.keras.utils import plot_model
就可以了。
第二个报错问题是:
Failed to import pydot. You must `pip install pydot` and install graphviz
意思是缺失两个包,一个pydot,一个graphviz。
我查了很多文章,很多方法比较麻烦,都需要手动下载,手动配置环境变量,后来看到一个很简单的方法,并且也测试有效。
直接在anaconda prompt里面:
- conda install graphviz
- conda install pydotplus
就可以了,不过安装过程好像会给你装上很多额外的包.....不过不影响环境,神经网络还是一样能跑。
导入包,构建一个网络。这个网络是Model类,采用函数API实现,稍微复杂点,可以看图就会清楚他的结构。
导入包
- import pandas as pd
- import numpy as np
- import matplotlib.pyplot as plt
-
- import tensorflow as tf
- import keras
- from keras.preprocessing import sequence
- from keras.models import Sequential,Model
- from keras.layers import Dense,Input, Dropout, Embedding, Flatten,MaxPooling1D,Conv1D,SimpleRNN,LSTM,GRU,Multiply
- from keras.layers import Bidirectional,Activation,BatchNormalization
- from keras.layers.merge import concatenate
-
- from keras.callbacks import EarlyStopping
- from tensorflow.keras import regularizers
- from keras.utils.np_utils import to_categorical
- from tensorflow.keras import optimizers
- from tensorflow.keras.utils import plot_model
定义模型:
- inputs = Input(name='inputs',shape=[64,100], dtype='float64')
- gru=Bidirectional(GRU(32,return_sequences=True,))(inputs)
- mlp = Dense(64,activation='relu')(gru)
- attention_probs = Dense(64, activation='softmax', name='attention_vec')(mlp)
- attention_mul = Multiply()([mlp, attention_probs])
- mlp = Dense(64)(attention_mul) #原始的全连接
- fla=Flatten()(mlp)
- output = Dense(2, activation='softmax')(fla)
- model = Model(inputs=[inputs], outputs=output)
- model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
画出图形:
plot_model(model,'new_model.png',show_shapes=True)
第一参数是神经网络模型,第二个参数是储存的图片名称,第三个是在图片上打印出每层的数据形状。

用这种图就能很方便的展示组建的模型的架构,多输入多输出都行。
show_shapes=True参数改为False,就可以简化展示图片,不打印形状。我这里换了一种结构的网络。
plot_model(model,'model2.png',show_shapes=False)
