有大佬对计算数据的均值和方差了解么

291 天前
 bler
def get_mean_std_value(loader):
    '''
    求数据集的均值和标准差
    :param loader:
    :return:
    '''
    data_sum,data_squared_sum,num_batches = 0,0,0

    for data,_ in loader:
        # data: [batch_size,channels,height,width]
        # 计算 dim=0,2,3 维度的均值和,dim=1 为通道数量,不用参与计算
        data_sum += torch.mean(data,dim=[0,2,3])    # [batch_size,channels,height,width]
        # 计算 dim=0,2,3 维度的平方均值和,dim=1 为通道数量,不用参与计算
        data_squared_sum += torch.mean(data**2,dim=[0,2,3])  # [batch_size,channels,height,width]
        # 统计 batch 的数量
        num_batches += 1
    # 计算均值
    mean = data_sum/num_batches
    # 计算标准差
    std = (data_squared_sum/num_batches - mean**2)**0.5
    return mean,std

为什么可以这样计算均值,从这个代码中我的到一个结论:"每个样本均值的和/样本数=整体数据的均值"

有点不太理解这个东西,有大佬能用数学公式证明一下吗

简单说明一下数据情况:这是 CIFAR10 数据集,每个样本的结构是( batch_size,channels,height,width), 即(样本数量,RGB 通道,图片高度,图片宽度)

1924 次点击
所在节点    程序员
8 条回复
dji38838c
291 天前
这个很显然呀。
比如:假如一共有 300 个数据(a1, a2,... a300),分成 100 组,每组 3 个。
那么 [(a1+a2+a3)/3 + (a4+a5+a6)/3 + .... (a298+a299+a300) / 3] / 100
可以整理成 [(a1+a2+a3+...+a300)/3] / 100 = (a1+a2+...+a300)/300
NessajCN
291 天前
这里能这么算的前提是每个样本的采样数量,也就是计算始终用来计算 torch.mean() 的分母,都是一样的才成立
bler
291 天前
@dji38838c 我发现这个方法还是存在很大的问题的,这种方法只适用于"总数据量/batch_size=整数"这种情况下的计算出的结果才能成立。假设最后的数据恰好是一个异常数据,那么通过这个计算方式计算出来的均值就是有极大异常的均值
Eureka0
291 天前
这个只对每个样本的样本容量都一样的情况成立,其实就是
均值=求和(样本均值 i*样本容量 i)/求和(样本容量 i)
样本容量都一样就可以约掉了
nno
291 天前
@bler 在数据量非常大的情况下,少计算几个样本并不会对最终结果有很大影响。你要求的绝对正确只会出现在理论计算场景下
Sawyerhou
291 天前
如楼上所说,数据量很大的情况下,怎么算都差不多。
faterazer
291 天前
理论和现实是有 gap 的,楼主的疑问很正常,这就是一种近似计算,当然你也可以算整个数据集的精确均值和方差(更麻烦以及更多的计算时间)。在实践中,近似计算和精确计算不会带来太大的性能差异,一般都是按方便的来。另外 CIFAR10 这样的开源数据集的均值方差都有算好的直接用就行
bler
283 天前
已经发现这个问题了,我用 chatgpt 问了一下,好多答案都是计算一个大概值,不是一个精确值

这是一个专为移动设备优化的页面(即为了让你能够在 Google 搜索结果里秒开这个页面),如果你希望参与 V2EX 社区的讨论,你可以继续到 V2EX 上打开本讨论主题的完整版本。

https://ex.noerr.eu.org/t/1105178

V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。

V2EX is a community of developers, designers and creative people.

© 2021 V2EX