分布式机器学习中的模型聚合
我接下来的这篇论文在联邦学习(distributed learning)的上下文中引入了多任务学习,这意味着每个客户端/任务节点的训练数据分布是不同的,这样每个任务节点可以学习不同的模型,并且每个任务节点和全局模型由多个组件模型集成。本文的关键和核心在于各个任务节点所学习的模型的聚合/通信。根据模型聚合方式的不同,模型中使用的算法可以分为客户端-服务器方法和完全分散方法(实际上没有其他聚合方法,比如另一篇论文中提出的聚类方法,我们这里暂时跳过),这两种方法都可以在具体实现中通过优化代理损失函数来替代。
论文[1]在联邦(分布式)学习的背景下引入了多任务学习,这意味着每个客户端/任务节点的训练数据分布是不同的,从而每个任务节点学习不同的模型,每个任务节点和全局模型由多个组件模型集成。本文的关键和核心在于各个任务节点所学习的模型的聚合/通信。根据模型的聚合方式不同,模型中使用的算法可以分为客户端-服务器方法和完全分散方法(实际上没有其他聚合方法,如论文[3]中提出的聚类聚合方法,代码如[4]所示,这里我们暂时跳过)。
因为要实现的任务aggregator有很多种,论文代码采取的措施(Github上开源,见[2])是先实现aggregator的抽象基类,实现一些通用的方法,并指定抽象方法的接口,然后具体的任务aggregator类继承抽象基类,再做具体的实现。
让我们首先看看任务聚合器的抽象基类。
类别聚合器(ABC):
r ' '是聚合器的基类。“聚合器”指定客户端“”之间的通信
def __init__(
自我,
客户,
global _ learners _ ensemble,
log_freq,
global_train_logger,
全局_测试_记录器,
sampling _ rate=1。
样本替换=假,
测试客户端=无,
verbose=0,
种子=无,
*参数,
**kwargs
):
rng_seed=(seed if (seed不是None,seed=0) else int(time.time()))
Self.rng=random.random (RNG种子)#随机数生成器
self . NP _ RNG=NP . random . default _ RNG(RNG _ seed)# numpy随机数生成器
如果test_clients为None:
test_clients=[]
客户端=客户端#列表[客户端]
自测客户端=测试客户端#列表[客户端]
self . global _ leaders _ ensemble=global _ leaders _ ensemble # List[leaderer]
self . device=self . global _ leaders _ ensemble . device
self.log_freq=log_freq
self.verbose=verbose
# verbose:调整输出打印的冗余度,
# ` 0 '表示安静(没有任何打印输出),` 1 '显示日志,` 2 '显示所有本地日志;默认值为“0”。
self . global _ train _ logger=global _ train _ logger
self . global _ test _ logger=global _ test _ logger
self . model _ dim=self . global _ leaders _ aggregate . model _ dim # #模型特征维度
self.n_client
s = len(clients)
self.n_test_clients = len(test_clients)
self.n_learners = len(self.global_learners_ensemble)
# 存储为每个client分配的权重(权重为0-1之间的小数)
self.clients_weights =\
torch.tensor(
[client.n_train_samples for client in self.clients],
dtype=torch.float32
)
self.clients_weights = self.clients_weights / self.clients_weights.sum()
self.sampling_rate = sampling_rate # clients在每一轮使用的比例,默认为`1.`
self.sample_with_replacement = sample_with_replacement #对client进行采用是可重复还是无重复的,with_replacement=True表示可重复的,否则是不可重复的
# 每轮迭代需要使用到的client个数
self.n_clients_per_round = max(1, int(self.sampling_rate * self.n_clients))
# 采样得到的client列表
self.sampled_clients = list()
# 记载当前的迭代通信轮数
self.c_round = 0
self.write_logs()
@abstractmethod
def mix(self):
"""
该方法用于完成各client之间的权重参数与通信操作
"""
pass
@abstractmethod
def update_clients(self):
"""
该方法用于将所有全局分量模型拷贝到各个client,相当于boardcast操作
"""
pass
def update_test_clients(self):
"""
将全局(gobal)的所有分量模型都拷贝到各个client上
"""
def write_logs(self):
"""
对全局(global)的train和test数据集的loss和acc做记录
需要对所有client的所有样本做累加,然后除以所有client的样本总数做平均。
"""
def save_state(self, dir_path):
"""
保存aggregator的模型state,。例如, `global_learners_ensemble`中每个分量模型'learner'的state字典(以`.pt`文件格式),以及`self.clients` 中每个client的 `learners_weights` (注意,这个权重不是模型内部的参数,而是进行继承的时候对各个分量模型赋予的权重,包含train和test两部分,以一个大小为n_clients(n_test_clients)× n_learners的numpy数组的格式,即`.npy` 文件)。
"""
def load_state(self, dir_path):
"""
加载aggregator的模型state,即save_state方法里保存的那些
"""
def sample_clients(self):
"""
对clients进行采样,
如果self.sample_with_replacement为True,则为可重复采样,
否则,则为不可重复采用。
最终得到一个clients子集列表并赋予self.sampled_clients
"""
1.client-server 算法
这种方式的通信/聚合方法也称中心化(centralized)方法,因为该方法在每一轮迭代最后将所有client的权重数据汇集到server节点。这种方法的优化迭代部分的伪代码示意如下:
落实到具体代码实现上,这种方法的Aggregator设计如下:
class CentralizedAggregator(Aggregator):
r""" 标准的中心化Aggreagator
所有clients在每一轮迭代末和average client完全同步.
"""
def mix(self):
self.sample_clients()
# 对self.sampled_clients中每个client的参数进行优化
for client in self.sampled_clients:
# 相当于伪代码第11行调用的LocalSolver函数
client.step()
# 遍历global模型(self.global_learners_ensemble) 中每一个分量模型(learner)
# 相当于伪代码第13行
for learner_id, learner in enumerate(self.global_learners_ensemble):
# 获取所有client中对应learner_id的分量模型
learners = [client.learners_ensemble[learner_id] for client in self.clients]
# global模型的分量模型为所有client对应分量模型取平均,相当于伪代码第14行
average_learners(learners, learner, weights=self.clients_weights)
# 将更新后的模型赋予所有clients,相当于伪代码第5行的boardcast操作
self.update_clients()
# 通信轮数+1
self.c_round += 1
if self.c_round % self.log_freq == 0:
self.write_logs()
def update_clients(self):
"""
此函数负责将所有全局分量模型拷贝到各个client,相当于伪代码中第5行的boardcast操作
"""
for client in self.clients:
for learner_id, learner in enumerate(client.learners_ensemble):
copy_model(learner.model, self.global_learners_ensemble[learner_id].model)
if callable(getattr(learner.optimizer, "set_initial_params", None)):
learner.optimizer.set_initial_params(
self.global_learners_ensemble[learner_id].model.parameters()
)
2. fully decentralized(完全去中心化)算法
这种方法之所以被称为去中心化的,因为该方法在每一轮迭代不需要所有client的权重数据汇集到一个特定的server节点,而只需要完成每个节点和其邻居进行通信(参数共享)即可。这种方法的优化迭代部分的伪代码示意如下:
落实到具体代码实现上,这种方法的Aggregator设计如下:
class DecentralizedAggregator(Aggregator):
def __init__(
self,
clients,
global_learners_ensemble,
mixing_matrix,
log_freq,
global_train_logger,
global_test_logger,
sampling_rate=1.,
sample_with_replacement=True,
test_clients=None,
verbose=0,
seed=None):
super(DecentralizedAggregator, self).__init__(
clients=clients,
global_learners_ensemble=global_learners_ensemble,
log_freq=log_freq,
global_train_logger=global_train_logger,
global_test_logger=global_test_logger,
sampling_rate=sampling_rate,
sample_with_replacement=sample_with_replacement,
test_clients=test_clients,
verbose=verbose,
seed=seed
)
self.mixing_matrix = mixing_matrix
assert self.sampling_rate = 1, "partial sampling is not supported with DecentralizedAggregator"
def update_clients(self):
pass
def mix(self):
# 对各clients的模型参数进行优化
for client in self.clients:
client.step()
# 存储每个模型各参数混合的权重
# 行对应不同的client,列对应单个模型中不同的参数
# (注意:每个分量有独立的mixing_matrix)
mixing_matrix = torch.tensor(
self.mixing_matrix.copy(),
dtype=torch.float32,
device=self.device
)
# 遍历global模型(self.global_learners_ensemble) 中每一个分量模型(learner)
# 相当于伪代码第14行
for learner_id, global_learner in enumerate(self.global_learners_ensemble):
# 用于将指定learner_id的各client的模型state读出暂存
state_dicts = [client.learners_ensemble[learner_id].model.state_dict() for client in self.clients]
# 遍历global模型中的各参数
for key, param in global_learner.model.state_dict().items():
shape_ = param.shape
models_params = torch.zeros(self.n_clients, int(np.prod(shape_)), device=self.device)
for ii, sd in enumerate(state_dicts):
# models_params的第ii个下标存储的是第ii个client的参数
models_params[ii] = sd[key].view(1, -1)
# models_params的每一行是一个client的参数
# @符号表示矩阵乘/矩阵向量乘
# 故这里表示每个client参数是其他所有client参数的混合
models_params = mixing_matrix @ models_params
for ii, sd in enumerate(state_dicts):
# 将第ii个client的参数存入state_dicts中对应位置
sd[key] = models_params[ii].view(shape_)
# 将更新好的参数从state_dicts存入各client节点的模型中
for client_id, client in enumerate(self.clients):
client.learners_ensemble[learner_id].model.load_state_dict(state_dicts[client_id])
# 通信轮数+1
self.c_round += 1
if self.c_round % self.log_freq == 0:
self.write_logs()
参考文献
- [1] Marfoq O, Neglia G, Bellet A, et al. Federated multi-task learning under a mixture of distributions[J]. Advances in Neural Information Processing Systems, 2021, 34.
- [2] https://github.com/omarfoq/FedEM
- [3] Clustered Federated Learning: Model-Agnostic Distributed Multi-Task Optimization under Privacy Constraints
- [4] https://github.com/felisat/clustered-federated-learning
数学是符号的艺术,音乐是上界的语言。
内容来源网络,如有侵权,联系删除,本文地址:https://www.230890.com/zhan/132063.html