一致性哈希算法学习笔记

一致性哈希算法(Consistent Hashing)学习笔记。

主要内容从参考资料中摘抄,版权归原作者所有。

参考资料:

分布式缓存问题

假设我们有一个网站,最近发现随着流量增加,服务器压力越来越大,之前直接读写数据库的方式不太给力了,于是我们想引入 Memcached 作为缓存机制。现在我们一共有三台机器可以作为 Memcached 服务器,如下图所示。

很显然,最简单的策略是将每一次 Memcached 请求随机发送到一台 Memcached 服务器,但是这种策略可能会带来两个问题:

  1. 同一份数据可能被存在不同的机器上而造成数据冗余;
  2. 有可能某数据已经被缓存但是访问却没有命中,因为无法保证对相同 key 的访问都被发送到相同的服务器。

因此,随机策略无论是时间效率还是空间效率都非常不好。

要解决上述问题需要做到如下一点:保证对相同 key 的访问会被发送到相同的服务器。很多方法可以实现这一点,最常用的方法是计算哈希。例如对于每次访问,可以按如下算法计算哈希值:

h = Hash(key) % 3

其中,Hash 是一个从字符串到正整数的哈希映射函数。这样,如果我们将 Memcached Server 分别编号为 0、1、2,那么就可以根据上述算式和 key 计算出服务器编号 h,然后去访问。

这个方法虽然解决了上面提到的两个问题,但是存在一些其他的问题,如果将上述方法抽象:

h = Hash(key) % N

这个算式计算每个 key 的请求应该被发送到哪台服务器,其中 N 为服务器的数量,并且服务器按照 0..(N-1) 进行编号。

这个算法的问题在于容错性和扩展性不好。所谓容错性是指当系统中某一个或几个服务器变得不可用时,整个系统是否可以正确高效运行;而扩展性是指当加入新的服务器后,整个系统是否可以正确高效运行。

现在假设有一台服务器宕机了,那么为了填补空缺,要将宕机的服务器从编号列表中移除,后面的服务器按顺序前移一位并将其编号值减一,此时每个 key 就要按 h = Hash(key) % (N-1) 重新计算;同样,如果新增了一台服务器,虽然原有服务器编号不用改变,但是要按 h = Hash(key) % (N+1) 重新计算哈希值。因此系统中一旦有服务器变更,大量的 key 会被重定位到不同有服务器从而造成大量的缓存不命中。而这种情况在分布式系统中是非常糟糕的。

一个设计良好的分布式哈希方案应该具有良好的单调性,即服务节点的增减不会造成大量哈希值重定位。一致性哈希算法就是这样一种哈希方案。

一致性哈希算法

算法简述

一致性哈希算法(Consistent Hashing)最早在论文《Consistent Hashing and Random Trees: Distributed Caching Protocols for Relieving Hot Spots on the World Wide Web》中被提出。简单来说,一致性哈希将整个哈希值空间组织成一个虚拟的圆环,如假设某哈希函数 H 的值空间为 0..232-1(即哈希值是一个32位无符号整数),整个哈希空间环如下图:

整个空间按顺时针方向组织。0 和 232-1 在零点钟方向重合。

下一步将各个服务器使用哈希函数 H 计算一个哈希,具体可以选择服务器的IP或主机名作为关键字进行哈希,这样每台机器就能确定其在哈希环上的位置,这里假设将上文中三台服务器使用IP地址哈希后再环空间的位置如下:

接下来使用如下算法定位数据访问到相应服务器:将数据 key 使用相同的哈希函数 H 计算出哈希值 h,根据 h 确定此数据在环上的位置,从此位置沿环顺时针“行走”,第一台遇到的服务器就是其应该定位到的服务器。

例如我们有 A、B、C、D 四个数据对象,经过哈希计算后,在环空间的位置如下:

根据一致性哈希算法,数据 A 会被定位到 Server 1 上,D 被定为到 Server 2 上,而 B、C 分别被定为到 Server 2 上。

容错性与可扩展性分析

下面分析一致性哈希算法的容错性和可扩展性。现假设 Server 3 宕机了:

可以看到此时 A、B、C 不会受到影响,只有数据 D 被重定位到 Serverr 2。一般的,在一致性哈希算法中,如果一台服务器不可用,则受影响的数据仅仅是此服务器到其环空间中前一台服务器(即沿逆时针方向行走遇到的第一台服务器)之间的数据,其他不会受到影响。

考虑可扩展性,如果我们在系统中增加一台服务器 Server 4:

此时 A、C、D 不受影响,只有 B 需要重定位到新的 Server 4。一般的,在一致性哈希算法中,如果增加一台服务器,则受影响的数据仅仅是新服务器到其环空间中前一台服务器(即沿逆时针方向行走遇到的第一台服务器)之间的数据,其他不会受到影响。

综上所述,一致性哈希算法对于节点的增减都只需要定位环空间中的一小部分数据,具有较好的容错性和可扩展性。

虚拟节点

一致性哈希算法在服务节点太少时,容易因为节点分布不均匀而造成数据倾斜问题。假如我们系统中有两台服务器,其环分布如下:

此时必然造成大量数据集中到 Server 1 上,而只有极少量数据定位到 Server 2 上。为了解决这种数据倾斜问题,一致性哈希算法引入了虚拟节点机制,即对每一个服务器计算多个哈希,每个计算结果位置都放置一个此服务节点,称为虚拟节点。

具体做法可以在服务器IP或主机名的后面增加编号来实现。例如上面的情况,我们决定为每台服务器计算三个虚拟节点,于是可以分别计算 "Memcached Server 1#1""Memcached Server 1#2""Memcached Server 1#3""Memcached Server 2#1""Memcached Server 2#2""Memcached Server 2#3" 的哈希值,于是形成六个虚拟节点:

同时数据定位算法不变,只是多了一步虚拟节点到实际节点的映射,例如定位到 "Memcached Server 1#1""Memcached Server 1#2""Memcached Server 1#3" 三个虚拟节点的数据均定位到 Server 1 上。这样就解决了服务节点少时数据倾斜的问题。在实际应用中,通常将虚拟节点数设置为32甚至更大,因此即使很少的服务节点也能做到相对均匀的数据分布。

相关分布式问题

节点权重

不同的节点处理能力可能不一致,处理能力强大的服务节点可以划分多一些虚拟节点,相应的处理能力较差的服务节点可以划分少一些虚拟节点。

在 OpenStack Swift 中引入了权重(Weight)的概念来做这件事情。

数据副本(Replica)

为了保证数据安全,分布是系统通常会使用冗余副本(Replica)来保证数据安全。

NWR 是一种在分布式存储系统中用于控制一致性级别的一种策略。每个字母的含义如下:

  • N: 同一份数据的 Replica 的份数
  • W: 更新一个数据对象的时候需要确保成功更新的份数
  • R: 读取一个数据需要读取的 Replica 的份数

在健壮的分布式系统中,数据的单点是不允许存在的。即线上正常存在的 Replica 数量是1的情况是非常危险的,因为一旦这个 Replica 发生错误,就可能发生数据的永久性错误。加入我们把 N 设置成为2,那么只要有一个存储节点发生损坏,就会有单点的存在,所以 N 必须大于2。N 越大,系统的维护和整体成本就越高,工业界通常把副本数量 N 设置为3。

分区(Zone)

考虑 CAP 定理中的分区容错性(P),分布式系统需要一种机制对服务器的物理资源进行隔离。

OpenStack Swift 中引入了 Zone 的概念对服务器进行物理隔离。所有的服务节点被分割到不同的 Zone 中,每个虚拟节点(Partition)的副本(Replica)不能放在同一个 Zone 内。

总结

目前一致性哈希基本成为了分布式系统组建的标准配置,例如 Memcached 的各种客户端都提供内置的一致性哈希支持。在企业级产品中应用非常广泛,如 Amazon DynamoDB,OpenStack Swift 等。

抄一段测试代码:

In [1]:
# http://www.cnblogs.com/yuxc/archive/2012/06/22/2558312.html
from array import array
from hashlib import md5
from random import shuffle
from struct import unpack_from
from time import time

class Ring(object):
    def __init__(self, nodes, part2node, relicas):
        self.nodes = nodes
        self.part2node = part2node
        self.replicas = replicas
        partition_power = 1
        while 2 ** partition_power < len(part2node):
            partition_power += 1
        if len(part2node) != 2 ** partition_power:
            raise ValueError('part2node length is not an exact power of 2')
        self.partition_shift = 32 - partition_power

    def get_nodes(self, data_id):
        data_id = bytes(str(data_id), 'utf8')
        part = unpack_from('>I', md5(data_id).digest()
                          )[0] >> self.partition_shift
        node_ids = [self.part2node[part]]
        zones = [self.nodes[node_ids[0]]]
        for replica in range(1, self.replicas):
            while (self.part2node[part] in node_ids and
                    self.nodes[self.part2node[part]] in zones):
                part += 1
                if part >= len(self.part2node):
                    part = 0
            node_ids.append(self.part2node[part])
            zones.append(self.nodes[node_ids[-1]])
        return [self.nodes[n] for n in node_ids]

def build_ring(nodes, partition_power, replicas):
    begin = time()
    parts = 2 ** partition_power
    total_weight = float(sum(n['weight'] for n in nodes.values()))
    for node in nodes.values():
        node['desired_parts'] = parts / total_weight * node['weight']
    part2node = array('H')
    for part in range(2 ** partition_power):
        for node in nodes.values():
            if node['desired_parts'] >= 1:
                node['desired_parts'] -= 1
                part2node.append(node['id'])
                break
        else:
            for node in nodes.values():
                if node['desired_parts'] >= 0:
                    node['desired_parts'] -= 1
                    part2node.append(node['id'])
                    break
    shuffle(part2node)
    ring = Ring(nodes, part2node, replicas)
    print('%.02fs to build ring' % (time() - begin))
    return ring

def test_ring(ring, replicas):
    begin = time()
    data_id_count = 10000000
    node_counts = {}
    zone_counts = {}
    for data_id in range(data_id_count):
        for node in ring.get_nodes(data_id):
            node_counts[node['id']] = (
                node_counts.get(node['id'], 0) + 1)
            zone_counts[node['zone']] = (
                zone_counts.get(node['zone'], 0) + 1)
    print('%.02fs to test ring' % (time() - begin))
    
    total_weight = float(sum(n['weight']
                             for n in ring.nodes.values()))
    max_over = 0
    max_under = 0
    for node in ring.nodes.values():
        desired = data_id_count * replicas * node['weight'] / total_weight
        diff = node_counts[node['id']] - desired
        if diff > 0:
            over = 100.0 * diff / desired
            if over > max_over:
                max_over = over
        else:
            under = - 100.0 * diff / desired
            if under > max_under:
                max_under = under
    print('%.02f%% max node over' % max_over)
    print('%.02f%% max node under' % max_under)
    
    max_over = 0
    max_under = 0
    for zone in set(n['zone'] for n in ring.nodes.values()):
        zone_weight = sum(n['weight']
                          for n in ring.nodes.values()
                          if n['zone'] == zone)
        desired = data_id_count * replicas * zone_weight / total_weight
        diff = zone_counts[zone] - desired
        if diff > 0:
            over = 100.0 * diff / desired
            if over > max_over:
                max_over = over
        else:
            under = - 100.0 * diff / desired
            if under > max_under:
                max_under = under
    print('%.02f%% max zone over' % max_over)
    print('%.02f%% max zone under' % max_under)


partition_power = 16
replicas = 3
node_count = 256
zone_count = 16
nodes = {}
while len(nodes) < node_count:
    zone = 0
    while zone < zone_count and len(nodes) < node_count:
        node_id = len(nodes)
        nodes[node_id] = {'id': node_id, 'zone': zone,
                          'weight': 1.0 + (node_id % 2)}
        zone += 1
ring = build_ring(nodes, partition_power, replicas)
test_ring(ring, replicas)

1.15s to build ring
81.58s to test ring
1.38% max node over
1.67% max node under
0.23% max zone over
0.21% max zone under