gotensor/model.go

125 lines
2.4 KiB
Go

package gotensor
import (
	"encoding/json"
	"fmt"
	"os"
)
// Model is the interface a model must implement: a forward pass plus access
// to its parameters and their gradients.
type Model interface {
	// Forward takes an input tensor and returns an output tensor or an error.
	Forward(inputs *Tensor) (*Tensor, error)
	Parameters() []*Tensor // returns all parameters of the model
	ZeroGrad() // resets the gradients of all parameters to zero
}
// Sequential is a model made of a list of layers that are executed in order.
type Sequential struct {
	// Layers holds the layers in the order they are applied.
	Layers []Layer
}
// Layer is the interface implemented by a single layer. Sequential stores its
// layers as values of this type and calls these methods on each of them.
type Layer interface {
	// Forward takes an input tensor and returns an output tensor or an error.
	Forward(inputs *Tensor) (*Tensor, error)
	// Parameters returns the tensors that Sequential collects as model parameters.
	Parameters() []*Tensor
	// ZeroGrad is invoked on every layer by Sequential.ZeroGrad.
	ZeroGrad()
}
// Forward runs the forward pass: inputs is handed to the first layer, and each
// layer's output becomes the next layer's input. The output of the last layer
// is returned.
//
// If any layer fails, the pass stops immediately and that layer's error is
// returned together with a nil tensor. With no layers, inputs is returned as is.
func (s *Sequential) Forward(inputs *Tensor) (*Tensor, error) {
	current := inputs
	for i := range s.Layers {
		next, err := s.Layers[i].Forward(current)
		if err != nil {
			return nil, err
		}
		current = next
	}
	return current, nil
}
// Parameters collects the parameters of every layer into one slice, keeping
// the layers' order and, within a layer, the order that layer reports.
// The result is nil when no layer contributes any parameter.
func (s *Sequential) Parameters() []*Tensor {
	var all []*Tensor
	for i := 0; i < len(s.Layers); i++ {
		layerParams := s.Layers[i].Parameters()
		all = append(all, layerParams...)
	}
	return all
}
// ZeroGrad clears the gradients of all parameters by calling ZeroGrad on each
// layer, in layer order.
func (s *Sequential) ZeroGrad() {
	for i := range s.Layers {
		s.Layers[i].ZeroGrad()
	}
}
// SaveModel writes all parameters of model to the file at filepath as JSON.
//
// The file holds a single JSON array of arrays: entry i is the i-th tensor
// returned by model.Parameters(), flattened in row-major order. Tensor shapes
// are not stored in the file.
//
// Only rank-1 and rank-2 parameters can be flattened here; any other rank is
// reported as an error instead of being written out as zeros. An error from
// reading an element, creating the file, encoding the data, or closing the
// file is returned as well. The file is created (truncating an existing one)
// only after every parameter has been read successfully.
func SaveModel(model Model, filepath string) error {
	params := model.Parameters()
	paramsData := make([][]float64, len(params))
	for i, param := range params {
		shape := param.Shape()
		rank := len(shape)
		if rank != 1 && rank != 2 {
			return fmt.Errorf("saving parameter %d: unsupported tensor rank %d", i, rank)
		}

		data := make([]float64, param.Size())
		for idx := range data {
			var (
				val float64
				err error
			)
			if rank == 1 {
				val, err = param.Data.Get(idx)
			} else {
				// Row-major: the flat index splits into (row, column).
				cols := shape[1]
				val, err = param.Data.Get(idx/cols, idx%cols)
			}
			if err != nil {
				return fmt.Errorf("saving parameter %d: reading element %d: %w", i, idx, err)
			}
			data[idx] = val
		}
		paramsData[i] = data
	}

	// The os error is returned unwrapped so that callers using helpers such as
	// os.IsPermission keep working.
	file, err := os.Create(filepath)
	if err != nil {
		return err
	}
	if err := json.NewEncoder(file).Encode(paramsData); err != nil {
		file.Close()
		return err
	}
	// Close can report a failed write, so its error must not be dropped.
	return file.Close()
}
// LoadModel reads parameters from the JSON file at filepath and stores them
// into the tensors returned by model.Parameters().
//
// The file must hold a JSON array of arrays with one entry per model
// parameter, in the same order, each entry being the tensor's values
// flattened in row-major order.
//
// An error is returned when the file cannot be opened or decoded, when the
// number of entries differs from the number of model parameters, when an
// entry's length differs from its tensor's size, or when a parameter is not
// rank 1 or rank 2. All of these checks run before any value is written, so a
// rejected file leaves the model's parameters unmodified.
func LoadModel(model Model, filepath string) error {
	// The os error is returned unwrapped so that callers using helpers such as
	// os.IsNotExist keep working.
	file, err := os.Open(filepath)
	if err != nil {
		return err
	}
	defer file.Close()

	var paramsData [][]float64
	if err := json.NewDecoder(file).Decode(&paramsData); err != nil {
		return err
	}

	params := model.Parameters()
	if len(params) != len(paramsData) {
		return fmt.Errorf("loading model from %q: file holds %d parameters, model has %d",
			filepath, len(paramsData), len(params))
	}

	// Validate every parameter up front so that a bad file cannot leave the
	// model partially overwritten.
	for i, param := range params {
		if rank := len(param.Shape()); rank != 1 && rank != 2 {
			return fmt.Errorf("loading parameter %d: unsupported tensor rank %d", i, rank)
		}
		if got, want := len(paramsData[i]), param.Size(); got != want {
			return fmt.Errorf("loading parameter %d: file holds %d values, tensor has %d", i, got, want)
		}
	}

	// Set's result, if it has one, is not inspected here; the checks above
	// keep every index within the shape and size the tensor reports.
	for i, param := range params {
		data := paramsData[i]
		shape := param.Shape()
		if len(shape) == 1 {
			for idx, val := range data {
				param.Data.Set(val, idx)
			}
			continue
		}
		// Row-major: the flat index splits into (row, column).
		cols := shape[1]
		for idx, val := range data {
			param.Data.Set(val, idx/cols, idx%cols)
		}
	}
	return nil
}