def get_mean_std_value(loader):
'''
求数据集的均值和标准差
:param loader:
:return:
'''
data_sum,data_squared_sum,num_batches = 0,0,0
for data,_ in loader:
# data: [batch_size,channels,height,width]
# 计算 dim=0,2,3 维度的均值和,dim=1 为通道数量,不用参与计算
data_sum += torch.mean(data,dim=[0,2,3]) # [batch_size,channels,height,width]
# 计算 dim=0,2,3 维度的平方均值和,dim=1 为通道数量,不用参与计算
data_squared_sum += torch.mean(data**2,dim=[0,2,3]) # [batch_size,channels,height,width]
# 统计 batch 的数量
num_batches += 1
# 计算均值
mean = data_sum/num_batches
# 计算标准差
std = (data_squared_sum/num_batches - mean**2)**0.5
return mean,std
为什么可以这样计算均值,从这个代码中我的到一个结论:"每个样本均值的和/样本数=整体数据的均值"
有点不太理解这个东西,有大佬能用数学公式证明一下吗
简单说明一下数据情况:这是 CIFAR10 数据集,每个样本的结构是( batch_size,channels,height,width), 即(样本数量,RGB 通道,图片高度,图片宽度)
这是一个专为移动设备优化的页面(即为了让你能够在 Google 搜索结果里秒开这个页面),如果你希望参与 V2EX 社区的讨论,你可以继续到 V2EX 上打开本讨论主题的完整版本。
V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。
V2EX is a community of developers, designers and creative people.