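// This example program demonstrates basic gotensor operations: element-wise
// addition, subtraction and multiplication, scalar scaling, matrix
// multiplication, and construction of zero and identity tensors.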
package main
import (
	"fmt"
	"log"

	"git.kingecg.top/kingecg/gotensor"
)
func main() {
|
|
fmt.Println("=== 基本运算示例 ===")
|
|
|
|
// 创建两个2x2的张量
|
|
t1_data := []float64{1, 2, 3, 4}
|
|
t1_shape := []int{2, 2}
|
|
t1, err := gotensor.NewTensor(t1_data, t1_shape)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
t2_data := []float64{5, 6, 7, 8}
|
|
t2, err := gotensor.NewTensor(t2_data, t1_shape)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
fmt.Printf("张量1:\n%s\n", t1.String())
|
|
fmt.Printf("张量2:\n%s\n", t2.String())
|
|
|
|
// 加法运算
|
|
add_result, err := t1.Add(t2)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
fmt.Printf("加法结果 (t1 + t2):\n%s\n", add_result.String())
|
|
|
|
// 减法运算
|
|
sub_result, err := t1.Subtract(t2)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
fmt.Printf("减法结果 (t1 - t2):\n%s\n", sub_result.String())
|
|
|
|
// 逐元素乘法
|
|
mul_result, err := t1.Multiply(t2)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
fmt.Printf("逐元素乘法结果 (t1 * t2):\n%s\n", mul_result.String())
|
|
|
|
// 数乘
|
|
scale_result := t1.Scale(2.0)
|
|
fmt.Printf("数乘结果 (t1 * 2):\n%s\n", scale_result.String())
|
|
|
|
// 矩阵乘法
|
|
matmul_result, err := t1.MatMul(t2)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
fmt.Printf("矩阵乘法结果 (t1 @ t2):\n%s\n", matmul_result.String())
|
|
|
|
// 创建零张量和单位矩阵
|
|
zeros, err := gotensor.NewZeros([]int{2, 3})
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
fmt.Printf("2x3零张量:\n%s\n", zeros.String())
|
|
|
|
identity, err := gotensor.NewIdentity(3)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
fmt.Printf("3x3单位矩阵:\n%s\n", identity.String())
|
|
} |