The Ever-Tinkering WaitGroup

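WaitGroup is a concurrency primitive commonly used for task orchestration in Go. It has only a few simple methods and looks easy to use. Under the hood, however, its implementation has changed several times, mostly to keep optimizing the atomic operations on its fields.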
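For readers who haven't used it, here is a minimal sketch of the typical pattern: call Add before starting each goroutine, call Done when the goroutine finishes, and call Wait to block until the counter drops back to zero. The helper runTasks is hypothetical and exists only for illustration; atomic.Int64 requires Go 1.19 or later.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// runTasks (hypothetical helper) starts n goroutines and uses a WaitGroup
// to wait until all of them have finished. It returns how many completed.
func runTasks(n int) int64 {
	var wg sync.WaitGroup
	var done atomic.Int64 // counts finished tasks

	for i := 0; i < n; i++ {
		wg.Add(1) // register the task before starting the goroutine
		go func() {
			defer wg.Done() // Done decrements the counter by one
			done.Add(1)
		}()
	}
	wg.Wait() // blocks until the counter is back to zero
	return done.Load()
}

func main() {
	fmt.Println(runTasks(10)) // prints 10
}
```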

The original WaitGroup implementation

The earliest WaitGroup implementation looked like this:

```go
type WaitGroup struct {
	m       Mutex
	counter int32
	waiters int32
	sema    *uint32
}

func (wg *WaitGroup) Add(delta int) {
	v := atomic.AddInt32(&wg.counter, int32(delta))
	if v < 0 {
		panic("sync: negative WaitGroup count")
	}
	if v > 0 || atomic.LoadInt32(&wg.waiters) == 0 {
		return
	}
	wg.m.Lock()
	for i := int32(0); i < wg.waiters; i++ {
		runtime_Semrelease(wg.sema)
	}
	wg.waiters = 0
	wg.sema = nil
	wg.m.Unlock()
}
```

The meaning of each field is clear, but the implementation is still somewhat rough; for example, sema is a pointer.

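Later, the counter and waiters fields were merged into a single 64-bit value. Because 64-bit atomic operations require 8-byte alignment, the code has to find the 8-byte-aligned position inside state1. sema also stopped being a pointer.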

```go
type WaitGroup struct {
	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state.
	state1 [12]byte
	sema   uint32
}

func (wg *WaitGroup) state() *uint64 {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1))
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[4]))
	}
}
```
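To make the "high 32 bits are counter, low 32 bits are waiter count" layout concrete, the sketch below packs two values into one such word and splits it again. The names pack and unpack are hypothetical; WaitGroup itself does the same shifts inline rather than through named helpers.

```go
package main

import "fmt"

// pack (hypothetical) stores the counter in the high 32 bits and the
// waiter count in the low 32 bits of a single uint64.
func pack(counter int32, waiters uint32) uint64 {
	return uint64(uint32(counter))<<32 | uint64(waiters)
}

// unpack (hypothetical) splits the word back into its two halves.
// Converting to int32 keeps only the low 32 bits of the shifted value.
func unpack(state uint64) (counter int32, waiters uint32) {
	return int32(state >> 32), uint32(state)
}

func main() {
	s := pack(3, 2)
	fmt.Printf("%#x\n", s) // 0x300000002
	c, w := unpack(s)
	fmt.Println(c, w) // 3 2
}
```

Keeping both halves in one word is what lets a single 64-bit atomic operation update the counter and observe the waiter count at the same moment.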

Later, WaitGroup settled on the following implementation, and it stayed that way for a while:

```go
type WaitGroup struct {
	noCopy noCopy

	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state, and the other 4 as storage
	// for the sema.
	state1 [3]uint32
}

// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}
```

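The state and sema fields were merged into a single field, state1, an array of three uint32 values of 4 bytes each. Because the array is only guaranteed 4-byte alignment, either its first element or its second element starts on an 8-byte boundary. The code uses the aligned 8 bytes as the state and the remaining 4 bytes as sema.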

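This implementation is correct, just a bit convoluted: you have to check the alignment of state1 before you know which part holds counter and waiters and which part holds sema.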

A question for you: what is the maximum number of waiters a WaitGroup can have?

Changes in Go 1.18

Go 1.18 changed WaitGroup again, this time targeting 64-bit architectures, where the compiler guarantees that fields of type uint64 are 8-byte aligned.

```go
type WaitGroup struct {
	noCopy noCopy

	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers only guarantee that 64-bit fields are 32-bit aligned.
	// For this reason on 32 bit architectures we need to check in state()
	// if state1 is aligned or not, and dynamically "swap" the field order if
	// needed.
	state1 uint64
	state2 uint32
}
```

Of course, to stay compatible with 32-bit architectures, the alignment still has to be checked:

```go
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if unsafe.Alignof(wg.state1) == 8 || uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		// state1 is 64-bit aligned: nothing to do.
		return &wg.state1, &wg.state2
	} else {
		// state1 is 32-bit aligned but not 64-bit aligned: this means that
		// (&state1)+4 is 64-bit aligned.
		state := (*[3]uint32)(unsafe.Pointer(&wg.state1))
		return (*uint64)(unsafe.Pointer(&state[1])), &state[0]
	}
}
```

Overall, this change brings a 9% to 30% performance improvement on linux/amd64.

Changes in Go 1.20

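The optimization wasn't over yet. In Go 1.19, Russ Cox implemented atomic.Uint64, which is 8-byte aligned on both 64-bit and 32-bit architectures. Why? Because it carries a special trump card: align64.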

```go
// A Uint64 is an atomic uint64. The zero value is zero.
type Uint64 struct {
	_ noCopy
	_ align64
	v uint64
}
```

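On 64-bit architectures this is no problem. On 32-bit architectures, when the Go compiler sees this field, it automatically aligns the struct to 8 bytes. This is a convention special-cased in the compiler, so adding an align64 of your own to a struct in your package has no effect.
However, if you want your own struct to be 8-byte aligned too, you can use the following technique: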

```go
import "sync/atomic"

type T struct {
	_ [0]atomic.Int64 // takes 0 bytes, but makes T 8-byte aligned
	x uint64          // so x is 8-byte aligned
}
```

With that, WaitGroup's implementation could be simplified to:

```go
type WaitGroup struct {
	noCopy noCopy

	state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
	sema  uint32
}
```

A separate state() method is no longer needed; the state field is used directly (race-detector code omitted):

```go
func (wg *WaitGroup) Add(delta int) {
	state := wg.state.Add(uint64(delta) << 32)
	v := int32(state >> 32)
	w := uint32(state)
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	if v > 0 || w == 0 {
		return
	}
	// This goroutine has set counter to 0 when waiters > 0.
	// Now there can't be concurrent mutations of state:
	// - Adds must not happen concurrently with Wait,
	// - Wait does not increment waiters if it sees counter == 0.
	// Still do a cheap sanity check to detect WaitGroup misuse.
	if wg.state.Load() != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Reset waiters count to 0.
	wg.state.Store(0)
	for ; w != 0; w-- {
		runtime_Semrelease(&wg.sema, false, 0)
	}
}
```