先看Code:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")
# Define model
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork().to(device)
print(model)
让我们逐步解析这段代码的每个部分。
1. 设备选择
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")
这部分代码用于选择计算设备(GPU、TPU等)来运行 PyTorch 模型。
-
torch.accelerator.is_available():
检查是否有可用的加速设备(如 GPU、TPU 等)。如果可用,则返回True,否则返回False。 -
torch.accelerator.current_accelerator().type:
获取当前加速设备的类型,例如:"cuda"代表 NVIDIA GPU(通常是 CUDA 设备)"mps"代表 Mac 上的 Metal 设备(Apple GPU)"xpu"代表 Intel GPU"tpu"代表 Google TPU
-
如果没有加速设备,则回退到
"cpu"。 -
print(f"Using {device} device")
输出当前使用的设备,例如:Using cuda device
注意:
torch.accelerator 是 PyTorch 2.1 及以上版本引入的更通用的加速接口。如果你使用的是早期版本,可能需要替换为:
device = "cuda" if torch.cuda.is_available() else "cpu"
2. 定义神经网络
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
-
class NeuralNetwork(nn.Module)
这里定义了一个神经网络类,它继承自torch.nn.Module,这使得它可以作为 PyTorch 模型。 -
super().__init__()
调用nn.Module的构造函数,确保父类正确初始化。 -
self.flatten = nn.Flatten()- 这个层会将 2D 图像(28×28 像素)展平为 1D 向量(长度 784)。
- 例如,一个形状为
(batch_size, 1, 28, 28)的输入,经过nn.Flatten()处理后,会变为(batch_size, 784)。
-
self.linear_relu_stack = nn.Sequential(...)
这是一个 顺序模型(nn.Sequential),包括以下层:nn.Linear(28*28, 512):
线性层,将 784 维输入转换为 512 维隐藏层。nn.ReLU():
采用 ReLU 激活函数,引入非线性。nn.Linear(512, 512):
第二个隐藏层,继续处理 512 维特征。nn.ReLU():
再次应用 ReLU 激活。nn.Linear(512, 10):
输出层,将特征转换为 10 维(对应于 10 个类别,例如手写数字 0-9)。
3. 前向传播
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
-
x = self.flatten(x):
将输入的 2D 图像展平成 1D。 -
logits = self.linear_relu_stack(x):
经过linear_relu_stack处理,最终得到logits(未归一化的类别分数)。 -
return logits:
输出logits,后续通常会用nn.CrossEntropyLoss()来计算损失。
4. 创建模型并移动到指定设备
model = NeuralNetwork().to(device)
print(model)
NeuralNetwork()创建模型对象。.to(device)将模型移动到 CPU 或 GPU(如果可用)。print(model)输出模型结构,例如:NeuralNetwork( (flatten): Flatten(start_dim=1, end_dim=-1) (linear_relu_stack): Sequential( (0): Linear(in_features=784, out_features=512, bias=True) (1): ReLU() (2): Linear(in_features=512, out_features=512, bias=True) (3): ReLU() (4): Linear(in_features=512, out_features=10, bias=True) ) )
总结
- 选择计算设备(GPU/TPU/CPU)。
- 定义神经网络:
Flatten展平输入。Linear层进行特征变换。ReLU作为激活函数。
- 定义前向传播(
forward方法)。 - 创建模型并移动到设备。
这段代码通常用于MNIST 手写数字分类,但可以扩展到其他任务![]()
