golang源码分析-深入理解WaitGroup

WaitGroup介绍

实际开发中经常会碰到一种情况:一个聚合信息的接口(如app首页/用户信息等),需要查询一大堆数据,而这些数据都是分散在各个服务中。同步去一个一个查询会导致接口响应过长。于是很自然而然想到了并发查询。

并发查询

针对这种场景,并发查询就需要做好并发控制,确保所有的查询都完成后再返回结果。这时候就需要用到sync.WaitGroup。(主角登场)

事实上,WaitGroup这个并发原语特别常见,在linux中就有类似的barrier机制。可见这种需要一个线程等待一组线程完成的场景是非常常见的。所以本文就来深入理解一下WaitGroup的应用与实现。

基本使用

本文的golang源码基于go1.22

既然要学习一个东西的实现原理,那么肯定需要先知道它是怎么使用的,先看看WaitGroup对外暴露了哪些方法:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
$ go doc sync.Waitgroup
package sync // import "sync"

type WaitGroup struct {
        // Has unexported fields.
}
    A WaitGroup waits for a collection of goroutines to finish. The main
    goroutine calls Add to set the number of goroutines to wait for. Then each
    of the goroutines runs and calls Done when finished. At the same time,
    Wait can be used to block until all goroutines have finished.

    A WaitGroup must not be copied after first use.

    In the terminology of the Go memory model, a call to Done “synchronizes
    before” the return of any Wait call that it unblocks.

func (wg *WaitGroup) Add(delta int)
func (wg *WaitGroup) Done()
func (wg *WaitGroup) Wait()

上面的英文注释已经很清晰了,WaitGroup是用来等待一组goroutine完成的。主goroutine调用Add方法设置等待的goroutine数量,然后每个goroutine运行并在完成后调用Done。同时,Wait可以用来在主goroutine进行阻塞,直到所有goroutine完成。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
type Calculator struct {
	Cnt int
	m   sync.Mutex
}

func (c *Calculator) Inc() {
	c.m.Lock()
	defer c.m.Unlock()
	c.Cnt++
}

func main() {
    // 先初始化一个WaitGroup
	var wg sync.WaitGroup
	cal := &Calculator{m: sync.Mutex{}}

	for i := 0; i < 10; i++ {
        // 每次开一个协程前先加一个需要等待的goroutine数量
		wg.Add(1)
		go func() {
            // 协程逻辑完成后调用Done, 表示该协程完成
			defer wg.Done()
            // 协程具体的业务逻辑
			cal.Inc()
		}()
	}

    // 主协程调用Wait, 来阻塞自己,等待所有协程完成
	wg.Wait()
    // output: 10
	fmt.Println("res:", cal.Cnt)
}

这里写了个demo, 用来展示WaitGroup的基本使用。具体的用法在代码中的注释里了。需要注意的是,通过Add方法设置等待者数量时,可以在每次开协程前调用(每次开之前加一),也可以在循环外部调用(一次性加完)。

源码实现

结构体定义

WaitGroup的结构体定义如下:

1
2
3
4
5
6
type WaitGroup struct {
	noCopy noCopy

	state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
	sema  uint32
}

第一个字段noCopy是一个实现sync.Locker接口的空结构体。目的只是为了在WaitGroup被复制时,可以通过go vet检测出来。各位以后编程也可以使用这个技巧,避免不小心复制了一个不应该复制的对象(通过实现Locker接口)。

第二个字段是关键了,state是一个复合型字段,类似于sync.Mutex.state,一个字段有多种含义。该字段总共有16个字节,前8个字节(高32位)用来存储协程数量,后8个字节(低32位)用来存储等待者数量。

第三个字段sema是信号量,用来阻塞/唤醒主协程。

Add

这里去除了一些竞态检测的代码,因为那不是本文重点,只保留关键逻辑:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
func (wg *WaitGroup) Add(delta int) {
	state := wg.state.Add(uint64(delta) << 32)
	v := int32(state >> 32)
	w := uint32(state)

	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	if v > 0 || w == 0 {
		return
	}
	// This goroutine has set counter to 0 when waiters > 0.
	// Now there can't be concurrent mutations of state:
	// - Adds must not happen concurrently with Wait,
	// - Wait does not increment waiters if it sees counter == 0.
	// Still do a cheap sanity check to detect WaitGroup misuse.
	if wg.state.Load() != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Reset waiters count to 0.
	wg.state.Store(0)
	for ; w != 0; w-- {
		runtime_Semrelease(&wg.sema, false, 0)
	}
}

接下来逐行解释一下:

  1. 先通过state.Add方法将协程数量加上。至于为什么要左移32位,是因为state的高32位存储协程数量,低32位存储等待者数量(上文说过)。
  2. v是协程数量(取state右移32位的结果),w是等待者数量(取state的低32位)。
  3. 判断协程数量是否为负数,如果是则panic。这个判断比较好理解,因为如果协程数量为负了,那么主协程还在阻塞啥呢,不就死锁了吗。
  4. 如果等待者不为0,且协程数量等于delta,说明Add方法和Wait方法同时调用了(下文会说明Wait方法会操作等待者数量),panic。为啥要加这个检测呢?是因为要确保Wait调用必须要在所有Add之后,不然可能会发生类似于"死锁"的情况。
  5. 如果协程数量大于0或者等待者数量为0,直接返回。这里的逻辑是,如果协程数量大于0,说明还有协程在运行,不需要唤醒等待者;如果等待者数量为0,说明没有等待者,也不需要唤醒。
  6. 程序运行到这里,说明此时协程数量为0并且等待者数量不为0。这时候再做一个检测,确保Add方法和Wait方法没有并发调用。如果有并发调用,panic。具体是取出此时的state和之前的state进行比较,如果不一致,说明此时Wait方法被调用了。
  7. 设置state为0,同时唤醒所有等待者。

总体来看,Add方法的逻辑比较清晰。大致就是给高32位加上你设置的值,然后判断是否直接返回,如果协程数量为0那就唤醒等待者。

Done

1
2
3
4
// Done decrements the WaitGroup counter by one.
func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

Done方法就是调用Add(-1),减少协程数量。为啥减一呢?因为这个Done方法是每个协程处理完逻辑后调用的,所以减一。

刚刚看Add实现时,如果协程数量为0就会去唤醒等待者。其实就是所有协程处理完逻辑后,调用Done方法,协程数量减为0,然后唤醒等待者。

于是我们猜测,调用Wait方法时会阻塞等待者,直到协程数量为0时Add方法再将其唤醒。

验证猜想,接下来看Wait方法的实现。

Wait

这里也是去除了一些竞态检测的代码,只保留关键逻辑:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
	for {
		state := wg.state.Load()
		v := int32(state >> 32)
		w := uint32(state)
		if v == 0 {
			// Counter is 0, no need to wait.
			return
		}
		// Increment waiters count.
		if wg.state.CompareAndSwap(state, state+1) {
			runtime_Semacquire(&wg.sema)
			if wg.state.Load() != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			return
		}
	}
}
  1. 先取出state,然后取出协程数量(v)和等待者数量(w)。
  2. 如果协程数量为0,直接返回。这个逻辑很好理解,如果协程数量为0,说明所有协程都处理完了,不需要等待。
  3. CAS操作,将等待者数量加1。如果没成功,则下一次循环再次尝试,这也是有个无限for循环的原因。
  4. 阻塞当前协程
  5. 唤醒后判断当前state是否为0,如果不为0,说明WaitGroup被复用了(有其他地方在本次Wait还没结束时就调用了Add),panic。
  6. 结束

总结

一句话总结,Add方法操作state的高32位。Wait方法操作state的低32位。Done方法就是Add(-1)。最后再加上一些检测逻辑。

updatedupdated2024-06-272024-06-27