package main

import (
    "fmt"

    "git.kingecg.top/kingecg/gotensor"
)

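// This example compares the SGD and Adam optimizers on a small
// linear-regression problem, first through the gotensor Trainer helper and
// then with a hand-written training loop.
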
// LinearLayer is a simple linear (fully connected) layer.
type LinearLayer struct {
    Weight *gotensor.Tensor
    Bias   *gotensor.Tensor
}

// NewLinearLayer creates a new linear layer. Note that the initial values
// below are hardcoded for a 2x2 layer, so this constructor effectively
// assumes inputSize == outputSize == 2; the arguments only determine the
// tensor shapes.
func NewLinearLayer(inputSize, outputSize int) *LinearLayer {
    // Initialize the weights and bias.
    weight, _ := gotensor.NewTensor([]float64{
        0.1, 0.2,
        0.3, 0.4,
    }, []int{outputSize, inputSize})

    bias, _ := gotensor.NewTensor([]float64{0.1, 0.1}, []int{outputSize})

    return &LinearLayer{
        Weight: weight,
        Bias:   bias,
    }
}

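// Forward applies the affine transformation output = inputs * Weight^T + Bias.
// The transposed weights are wrapped in a new Tensor with a zeroed Grad
// buffer, since Transpose here operates on the underlying Data rather than
// on the Tensor itself.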
func (l *LinearLayer) Forward(inputs *gotensor.Tensor) (*gotensor.Tensor, error) {
    // Perform the linear transformation: output = inputs * weight^T + bias.
    weightTransposed, err := l.Weight.Data.Transpose()
    if err != nil {
        return nil, err
    }

    // Wrap the transposed weights in a tensor with a fresh gradient buffer.
    weightTransposedTensor := &gotensor.Tensor{
        Data: weightTransposed,
        Grad: must(gotensor.NewZeros(l.Weight.Shape())),
    }

    mulResult, err := inputs.MatMul(weightTransposedTensor)
    if err != nil {
        return nil, err
    }

    output, err := mulResult.Add(l.Bias)
    if err != nil {
        return nil, err
    }

    return output, nil
}

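// Parameters returns the layer's trainable tensors. The optimizers created
// in main receive this slice, so parameter updates are presumably applied in
// place to these tensors.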
func (l *LinearLayer) Parameters() []*gotensor.Tensor {
    return []*gotensor.Tensor{l.Weight, l.Bias}
}

// ZeroGrad clears the accumulated gradients of the layer's parameters.
func (l *LinearLayer) ZeroGrad() {
    l.Weight.ZeroGrad()
    l.Bias.ZeroGrad()
}

// SimpleModel is a minimal model wrapping a single linear layer.
type SimpleModel struct {
    Layer *LinearLayer
}

func (m *SimpleModel) Forward(inputs *gotensor.Tensor) (*gotensor.Tensor, error) {
    return m.Layer.Forward(inputs)
}

func (m *SimpleModel) Parameters() []*gotensor.Tensor {
    return m.Layer.Parameters()
}

func (m *SimpleModel) ZeroGrad() {
    m.Layer.ZeroGrad()
}

// must is a helper that panics on error, for brevity in example code.
func must(t *gotensor.Tensor, err error) *gotensor.Tensor {
    if err != nil {
        panic(err)
    }
    return t
}

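// main runs the optimizer comparison: SGD training, Adam training, a
// prediction comparison, and finally a manual training loop.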
func main() {
    fmt.Println("Gotensor Advanced Optimizer Example")
    fmt.Println("Comparing the performance of different optimizers:")

    // Prepare the training data (a simple linear-regression problem).
    trainInputs := []*gotensor.Tensor{
        must(gotensor.NewTensor([]float64{1, 0}, []int{2})),
        must(gotensor.NewTensor([]float64{0, 1}, []int{2})),
        must(gotensor.NewTensor([]float64{1, 1}, []int{2})),
        must(gotensor.NewTensor([]float64{0, 0}, []int{2})),
    }

    trainTargets := []*gotensor.Tensor{
        must(gotensor.NewTensor([]float64{2, 0}, []int{2})),
        must(gotensor.NewTensor([]float64{0, 2}, []int{2})),
        must(gotensor.NewTensor([]float64{2, 2}, []int{2})),
        must(gotensor.NewTensor([]float64{0, 0}, []int{2})),
    }

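    // Each target is exactly 2x the input, so the ideal solution is
    // Weight = [[2, 0], [0, 2]] with a zero Bias.
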
    // Define the loss function (MSE). Tensor-op errors are ignored here for
    // brevity.
    lossFn := func(output, target *gotensor.Tensor) *gotensor.Tensor {
        // Mean squared error: mean((output - target)^2).
        diff, _ := output.Sub(target)
        squared, _ := diff.Mul(diff)
        sum, _ := squared.Sum()
        size := float64(output.Size())
        result, _ := sum.DivScalar(size)
        return result
    }

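    // The Trainer used below is assumed to run the forward pass, loss,
    // backward pass, and optimizer step for the given number of epochs
    // (100 here); the final boolean argument presumably toggles verbose
    // progress logging.
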
    // Test the SGD optimizer.
    fmt.Println("\n1. Training with the SGD optimizer:")
    sgdModel := &SimpleModel{
        Layer: NewLinearLayer(2, 2),
    }
    sgdOptimizer := gotensor.NewSGD(sgdModel.Parameters(), 0.1)
    sgdTrainer := gotensor.NewTrainer(sgdModel, sgdOptimizer)

    sgdInitialLoss, _ := sgdTrainer.Evaluate(trainInputs, trainTargets, lossFn)
    fmt.Printf("Initial loss: %.6f\n", sgdInitialLoss)

    err := sgdTrainer.Train(trainInputs, trainTargets, 100, lossFn, false)
    if err != nil {
        fmt.Printf("Error during SGD training: %v\n", err)
        return
    }

    sgdFinalLoss, _ := sgdTrainer.Evaluate(trainInputs, trainTargets, lossFn)
    fmt.Printf("Final SGD loss: %.6f\n", sgdFinalLoss)

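    // NewAdam takes a smaller learning rate (0.01) than SGD plus the
    // standard Adam defaults: beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8.
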
    // Test the Adam optimizer.
    fmt.Println("\n2. Training with the Adam optimizer:")
    adamModel := &SimpleModel{
        Layer: NewLinearLayer(2, 2),
    }
    adamOptimizer := gotensor.NewAdam(adamModel.Parameters(), 0.01, 0.9, 0.999, 1e-8)
    adamTrainer := gotensor.NewTrainer(adamModel, adamOptimizer)

    adamInitialLoss, _ := adamTrainer.Evaluate(trainInputs, trainTargets, lossFn)
    fmt.Printf("Initial loss: %.6f\n", adamInitialLoss)

    err = adamTrainer.Train(trainInputs, trainTargets, 100, lossFn, false)
    if err != nil {
        fmt.Printf("Error during Adam training: %v\n", err)
        return
    }

    adamFinalLoss, _ := adamTrainer.Evaluate(trainInputs, trainTargets, lossFn)
    fmt.Printf("Final Adam loss: %.6f\n", adamFinalLoss)

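    // Since the target function is y = 2x, a well-trained model should map
    // the input [0.5, 0.5] to values close to [1, 1].
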
    // Compare the predictions of the two models.
    fmt.Println("\nComparing the predictions of the two models:")
    testInput := must(gotensor.NewTensor([]float64{0.5, 0.5}, []int{2}))

    sgdOutput, _ := sgdModel.Forward(testInput)
    adamOutput, _ := adamModel.Forward(testInput)

    sgdOut0, _ := sgdOutput.Data.Get(0)
    sgdOut1, _ := sgdOutput.Data.Get(1)
    adamOut0, _ := adamOutput.Data.Get(0)
    adamOut1, _ := adamOutput.Data.Get(1)

    fmt.Printf("Input: [0.5, 0.5]\n")
    fmt.Printf("SGD output: [%.6f, %.6f]\n", sgdOut0, sgdOut1)
    fmt.Printf("Adam output: [%.6f, %.6f]\n", adamOut0, adamOut1)

    // Demonstrate using an optimizer manually.
    fmt.Println("\n3. Using an optimizer manually:")
    manualModel := &SimpleModel{
        Layer: NewLinearLayer(2, 2),
    }
    manualOptimizer := gotensor.NewAdam(manualModel.Parameters(), 0.01, 0.9, 0.999, 1e-8)

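    // Each training step follows the canonical order: forward pass, loss,
    // backward pass, optimizer step, then zero the gradients. Updating after
    // every sample makes this plain stochastic (batch size 1) training.
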
    // Run a few training steps.
    for step := 0; step < 5; step++ {
        totalLoss := 0.0
        for i := 0; i < len(trainInputs); i++ {
            // Forward pass.
            output, err := manualModel.Forward(trainInputs[i])
            if err != nil {
                fmt.Printf("Forward pass error: %v\n", err)
                return
            }

            // Compute the loss.
            loss := lossFn(output, trainTargets[i])
            lossVal, _ := loss.Data.Get(0)
            totalLoss += lossVal

            // Backward pass.
            loss.Backward()

            // Update the parameters.
            manualOptimizer.Step()

            // Clear the gradients.
            manualOptimizer.ZeroGrad()
        }

        avgLoss := totalLoss / float64(len(trainInputs))
        fmt.Printf("Step %d, average loss: %.6f\n", step+1, avgLoss)
    }

fmt.Println("\n优化器示例完成!")
|
||
}
|