• CNN 卷积神经网络day3


    学习来源:https://blog.csdn.net/minfanphd/article/details/116974889

    一、网络构建

    1、initOperators 初始化若干算子。 并且注意到它们与已经初始化的成员变量有关。
    2、ALPHA 和 LAMBDA 是超参数, 可以自己设置。
    3、setup 进行整个网络的初始化。
    4、forward 和 backPropagation 与 ANN 同理, 但运算不同。
    5、对一批(batch)数据逐条执行 forward 和 backPropagation, 累积梯度后, 整批只调用一次 updateParameters 更新参数。
    6、部分代码。

    package CNN;
    /**
     * @time 2022/6/26
     * @author Liang Huang
     */
     
    import java.util.Arrays;
    import machinelearning.cnn.Dataset.Instance;
    import machinelearning.cnn.MathUtils.Operator;
    
    public class FullCnn {
    	/**
    	 * The learning rate: gradient updates are scaled by this value (see
    	 * multiplyAlpha). Mutable so it can be tuned, hence not final.
    	 */
    	private static double ALPHA = 0.85;

    	/**
    	 * The regularization (weight decay) coefficient; 0 disables decay.
    	 * Weights are scaled by (1 - LAMBDA * ALPHA) (see multiplyLambda).
    	 */
    	public static double LAMBDA = 0;

    	/**
    	 * Manage layers. An instance field (not static): each network owns its
    	 * own builder, so several FullCnn instances can coexist without
    	 * clobbering each other's layer structure.
    	 */
    	private LayerBuilder layerBuilder;

    	/**
    	 * Train using a number of instances simultaneously.
    	 */
    	private int batchSize;

    	/**
    	 * Divide the given value by the batch size (averages accumulated
    	 * gradients over the batch).
    	 */
    	private Operator divideBatchSize;

    	/**
    	 * Multiply the given value by ALPHA (applies the learning rate).
    	 */
    	private Operator multiplyAlpha;

    	/**
    	 * Multiply the given value by (1 - LAMBDA * ALPHA) (weight decay).
    	 */
    	private Operator multiplyLambda;

    	/**
    	 *********************** 
    	 * The first constructor. Builds the layer structure, then the
    	 * operators that depend on batchSize / ALPHA / LAMBDA.
    	 * 
    	 * @param paraLayerBuilder the description of all layers, in order.
    	 * @param paraBatchSize    how many instances are processed per update.
    	 *********************** 
    	 */
    	public FullCnn(LayerBuilder paraLayerBuilder, int paraBatchSize) {
    		layerBuilder = paraLayerBuilder;
    		batchSize = paraBatchSize;
    		setup();
    		initOperators();
    	}// Of the first constructor

    	/**
    	 *********************** 
    	 * Initialize operators using anonymous classes. Must run after
    	 * batchSize is set, since divideBatchSize captures it.
    	 *********************** 
    	 */
    	private void initOperators() {
    		divideBatchSize = new Operator() {
    			private static final long serialVersionUID = 7424011281732651055L;

    			@Override
    			public double process(double value) {
    				return value / batchSize;
    			}// Of process
    		};

    		multiplyAlpha = new Operator() {
    			private static final long serialVersionUID = 5761368499808006552L;

    			@Override
    			public double process(double value) {
    				return value * ALPHA;
    			}// Of process
    		};

    		multiplyLambda = new Operator() {
    			private static final long serialVersionUID = 4499087728362870577L;

    			@Override
    			public double process(double value) {
    				return value * (1 - LAMBDA * ALPHA);
    			}// Of process
    		};
    	}// Of initOperators

    	/**
    	 *********************** 
    	 * Setup according to the layer builder: size each layer from its
    	 * predecessor and allocate kernels, biases, error maps and output maps.
    	 * 
    	 * @throws RuntimeException if an INPUT layer appears after layer 0.
    	 *********************** 
    	 */
    	public void setup() {
    		CnnLayer tempInputLayer = layerBuilder.getLayer(0);
    		tempInputLayer.initOutMaps(batchSize);

    		for (int i = 1; i < layerBuilder.getNumLayers(); i++) {
    			CnnLayer tempLayer = layerBuilder.getLayer(i);
    			CnnLayer tempFrontLayer = layerBuilder.getLayer(i - 1);
    			int tempFrontMapNum = tempFrontLayer.getOutMapNum();
    			switch (tempLayer.getType()) {
    			case INPUT:
    				// Only layer 0 may be the input layer.
    				throw new RuntimeException("INPUT layer at position " + i + " is not supported.");
    			case CONVOLUTION:
    				// Valid convolution shrinks the map by (kernel - 1).
    				tempLayer.setMapSize(
    						tempFrontLayer.getMapSize().subtract(tempLayer.getKernelSize(), 1));
    				tempLayer.initKernel(tempFrontMapNum);
    				tempLayer.initBias();
    				tempLayer.initErrors(batchSize);
    				tempLayer.initOutMaps(batchSize);
    				break;
    			case SAMPLING:
    				// Sampling keeps the map count and divides the map size.
    				tempLayer.setOutMapNum(tempFrontMapNum);
    				tempLayer.setMapSize(tempFrontLayer.getMapSize().divide(tempLayer.getScaleSize()));
    				tempLayer.initErrors(batchSize);
    				tempLayer.initOutMaps(batchSize);
    				break;
    			case OUTPUT:
    				// Output kernels cover the whole previous map (full connection).
    				tempLayer.initOutputKernel(tempFrontMapNum, tempFrontLayer.getMapSize());
    				tempLayer.initBias();
    				tempLayer.initErrors(batchSize);
    				tempLayer.initOutMaps(batchSize);
    				break;
    			default:
    				break;
    			}// Of switch
    		} // Of for i
    	}// Of setup

    	/**
    	 *********************** 
    	 * Forward computing: propagate one instance layer by layer.
    	 * 
    	 * @param instance the record to feed into the input layer.
    	 *********************** 
    	 */
    	private void forward(Instance instance) {
    		setInputLayerOutput(instance);
    		for (int l = 1; l < layerBuilder.getNumLayers(); l++) {
    			CnnLayer tempCurrentLayer = layerBuilder.getLayer(l);
    			CnnLayer tempLastLayer = layerBuilder.getLayer(l - 1);
    			switch (tempCurrentLayer.getType()) {
    			case CONVOLUTION:
    			case OUTPUT:
    				// The output layer is computed exactly like a convolution layer.
    				setConvolutionOutput(tempCurrentLayer, tempLastLayer);
    				break;
    			case SAMPLING:
    				setSampOutput(tempCurrentLayer, tempLastLayer);
    				break;
    			default:
    				break;
    			}// Of switch
    		} // Of for l
    	}// Of forward

    	/**
    	 *********************** 
    	 * Set the input layer output. Given a record, copy its values to the
    	 * input map (map 0).
    	 * 
    	 * @param paraRecord the record whose attributes fill the input map.
    	 * @throws RuntimeException if the attribute count does not match the
    	 *                          input map size.
    	 *********************** 
    	 */
    	private void setInputLayerOutput(Instance paraRecord) {
    		CnnLayer tempInputLayer = layerBuilder.getLayer(0);
    		Size tempMapSize = tempInputLayer.getMapSize();
    		double[] tempAttributes = paraRecord.getAttributes();
    		if (tempAttributes.length != tempMapSize.width * tempMapSize.height) {
    			throw new RuntimeException("input record does not match the map size.");
    		} // Of if

    		// Attributes are laid out in runs of 'height' values per i
    		// (index = height * i + j) — presumably row-major; confirm against
    		// the dataset loader.
    		for (int i = 0; i < tempMapSize.width; i++) {
    			for (int j = 0; j < tempMapSize.height; j++) {
    				tempInputLayer.setMapValue(0, i, j, tempAttributes[tempMapSize.height * i + j]);
    			} // Of for j
    		} // Of for i
    	}// Of setInputLayerOutput

    	/**
    	 *********************** 
    	 * Compute the convolution output according to the output of the last
    	 * layer: sum the valid convolutions over all input maps, add the bias,
    	 * and pass the result through the sigmoid activation.
    	 * 
    	 * @param paraLayer     the current layer.
    	 * @param paraLastLayer the last layer.
    	 *********************** 
    	 */
    	private void setConvolutionOutput(final CnnLayer paraLayer, final CnnLayer paraLastLayer) {
    		final int lastMapNum = paraLastLayer.getOutMapNum();

    		// Attention: paraLayer.getOutMapNum() may not be right.
    		for (int j = 0; j < paraLayer.getOutMapNum(); j++) {
    			double[][] tempSumMatrix = null;
    			for (int i = 0; i < lastMapNum; i++) {
    				double[][] lastMap = paraLastLayer.getMap(i);
    				double[][] kernel = paraLayer.getKernel(i, j);
    				if (tempSumMatrix == null) {
    					// On the first map.
    					tempSumMatrix = MathUtils.convnValid(lastMap, kernel);
    				} else {
    					// Sum up convolution maps.
    					tempSumMatrix = MathUtils.matrixOp(MathUtils.convnValid(lastMap, kernel),
    							tempSumMatrix, null, null, MathUtils.plus);
    				} // Of if
    			} // Of for i

    			// Activation: sigmoid(sum + bias), element-wise.
    			final double bias = paraLayer.getBias(j);
    			tempSumMatrix = MathUtils.matrixOp(tempSumMatrix, new Operator() {
    				private static final long serialVersionUID = 2469461972825890810L;

    				@Override
    				public double process(double value) {
    					return MathUtils.sigmod(value + bias);
    				}// Of process
    			});

    			paraLayer.setMapValue(j, tempSumMatrix);
    		} // Of for j
    	}// Of setConvolutionOutput

    	/**
    	 *********************** 
    	 * Compute the sampling (pooling) output according to the output of the
    	 * last layer: each map is scaled down by the layer's scale size.
    	 * 
    	 * @param paraLayer     the current layer.
    	 * @param paraLastLayer the last layer.
    	 *********************** 
    	 */
    	private void setSampOutput(final CnnLayer paraLayer, final CnnLayer paraLastLayer) {
    		// Use the getter for consistency with the rest of the class.
    		for (int i = 0; i < paraLayer.getOutMapNum(); i++) {
    			double[][] lastMap = paraLastLayer.getMap(i);
    			Size scaleSize = paraLayer.getScaleSize();
    			double[][] sampMatrix = MathUtils.scaleMatrix(lastMap, scaleSize);
    			paraLayer.setMapValue(i, sampMatrix);
    		} // Of for i
    	}// Of setSampOutput

    	/**
    	 *********************** 
    	 * The main entrance. Intentionally empty in this fragment.
    	 *********************** 
    	 */
    	public static void main(String[] args) {

    	}// Of main
    }// Of class FullCnn
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
  • 相关阅读:
    Meta开源新工具啊,Git地位危险了?
    章节十三:协程实践
    ZenCart 如何设置多个地区多个运费
    编译buildroot出错,这个怎么解决呢,感谢
    2022-08-27 第五组 张明敏 学习笔记
    多GPU训练的实现
    牛客网刷题篇
    阿里面试败北:5种微服务注册中心如何选型?这几个维度告诉你
    2022年最火的十大测试工具,你掌握了几个
    (附源码)springboot校园商铺系统 毕业设计 052145
  • 原文地址:https://blog.csdn.net/qq_44950283/article/details/125471717