• 深度学习手写简单的lstm


    代码如下;

    import pandas as pd
    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense, Dropout, Activation, Flatten, LSTM, TimeDistributed, RepeatVector
    from keras.layers.normalization.batch_normalization_v1 import BatchNormalization
    from keras.optimizers import Adam
    from keras.callbacks import EarlyStopping, ModelCheckpoint
    from sklearn.preprocessing import StandardScaler, MinMaxScaler
    import matplotlib.pyplot as plt
    n_in = 168  # 历史数量
    n_out = 24  # 预测数量
    n_features = 1
    # n_test = 1
    n_val = 1
    n_epochs = 250
    
    
    # 导入数据
    def load_stw_data() -> pd.DataFrame:   #在def那一行后面会加一个->。这个玩意儿有个专门的名词叫 type hint, 即类型提示
        df_stw = pd.read_excel('data3.xlsx')
        df_stw.columns = ['BillingDate', 'VolumnHL']
    
        return df_stw
    
    
    # MinMaxScaler数据归一化,可以帮助网络模型更快的拟合,稍微有一些提高准确率的效果
    def minmaxscaler(data: pd.DataFrame) -> pd.DataFrame:
        volume = data.VolumnHL.values
        volume = volume.reshape(len(volume), 1)
        volume = scaler.fit_transform(volume)
        volume = volume.reshape(len(volume), )
        data['VolumnHL'] = volume
    
        return data
    
    
    # 划分训练数据集和验证数据集,这里需要注意的是我么需要预测的数据是不可以出现在训练中的,切记。
    def split_data(x, y, n_test: int):
        x_train = x[:-n_val - n_out + 1]
        x_val = x[-n_val:]
        y_train = y[:-n_val - n_out + 1]
        y_val = y[-n_val:]
    
        return x_train, y_train, x_val, y_val
    
    
    # 划分X和Y
    def build_train(train, n_in, n_out):
        train = train.drop(["BillingDate"], axis=1)
        X_train, Y_train = [], []
        for i in range(train.shape[0] - n_in - n_out + 1):
            X_train.append(np.array(train.iloc[i:i + n_in]))
            Y_train.append(np.array(train.iloc[i + n_in:i + n_in + n_out]["VolumnHL"]))
    
        return np.array(X_train), np.array(Y_train)
    
    
    # 构建最简单的LSTM
    def build_lstm(n_in: int, n_features: int):
        model = Sequential()
        model.add(LSTM(12, activation='relu', input_shape=(n_in, n_features)))
        model.add(Dropout(0.3))
        model.add(Dense(n_out))
        model.compile(optimizer='adam', loss='mae')
    
        return model
    
    
    # 模型拟合
    def model_fit(x_train, y_train, x_val, y_val, n_features):
        model = build_lstm(
            n_in=n_in,
            n_features=1
        )
        model.compile(loss='mae', optimizer='adam')
        model.fit(x_train, y_train, epochs=n_epochs, batch_size=128, verbose=1, validation_data=(x_val, y_val))
        m = model.evaluate(x_val, y_val)
        print(m)
    
        return model
    
    
    data = load_stw_data()
    scaler = MinMaxScaler(feature_range=(0, 1))
    data = minmaxscaler(data)
    data_copy = data.copy()
    x, y = build_train(data_copy, n_in, n_out)
    x_train, y_train, x_val, y_val = split_data(x, y, n_val)
    model = build_lstm(n_in, 1)
    model = model_fit(x_train, y_train, x_val, y_val, 1)
    predict = model.predict(x_val)
    # predict = model.predict(x_val)
    validation = scaler.inverse_transform(predict)[0]
    validation
    actual = scaler.inverse_transform(y_val)[0]
    actual
    predict = validation
    actual = actual
    x = [x for x in range(24)]
    fig, ax = plt.subplots(figsize=(15,5),dpi = 300)
    ax.plot(x, predict, linewidth=2.0,label = "predict")
    ax.plot(x, actual, linewidth=2.0,label = "actual")
    ax.legend(loc=2);
    # ax.set_title(bf_name)
    plt.ylim((0, 900000))
    plt.grid(linestyle='-.')
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107

    在这里插入图片描述

  • 相关阅读:
    torch.nn中LSTM使用
    【LeetCode】69. x 的平方根
    认识vue3以及语法运用简介
    基于PaddleOCR的多视角集装箱箱号检测识别
    万字string类总结
    postman基础
    身份证号码,格式校验:@IdCard(Validation + Hutool)
    国标视频融合云平台EasyCVR视频汇聚平台关于远程控制的详细介绍
    Git版本控制管理——版本库管理
    QT的TCP连接功能概述
  • 原文地址:https://blog.csdn.net/qq_45706306/article/details/125627894