diff --git a/examples/cnn_example.go b/examples/cnn_example.go index f2f08e6..30a06f7 100644 --- a/examples/cnn_example.go +++ b/examples/cnn_example.go @@ -65,15 +65,30 @@ func main() { fmt.Printf("展平后大小: %d\n", flattened.Size()) // 创建一些随机权重进行全连接层操作 - weightsData := []float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6} - weightsShape := []int{flattened.Size(), 2} // 输出2类(猫/狗) + // 由于池化后是2x2,展平后应该是4个元素,所以我们需要4x2的权重矩阵 + flattenedSize := flattened.Size() + weightsData := make([]float64, flattenedSize * 2) // flattenedSize*2个权重值 + for i := range weightsData { + weightsData[i] = 0.1 * float64(i+1) // 填充一些递增的值 + } + weightsShape := []int{flattenedSize, 2} // 输出2类(猫/狗) weights, err := gotensor.NewTensor(weightsData, weightsShape) if err != nil { panic(err) } - // 全连接层计算 (矩阵乘法) - fcResult, err := flattened.MatMul(weights) + // 重塑flattened张量为2D格式以进行矩阵乘法 + reshapedFlattenedData := make([]float64, flattenedSize) + for i := 0; i < flattenedSize; i++ { + reshapedFlattenedData[i], _ = flattened.Data.Get(i) + } + reshapedFlattened, err := gotensor.NewTensor(reshapedFlattenedData, []int{1, flattenedSize}) // 作为1xN的矩阵 + if err != nil { + panic(err) + } + + // 全连接层计算 (矩阵乘法) - 现在是 (1, N) * (N, 2) = (1, 2) + fcResult, err := reshapedFlattened.MatMul(weights) if err != nil { panic(err) }