本文聯邦學習的代碼引用于https://github.com/shaoxiongji/federated-learning
本篇文章相當于帶大家讀一遍聯邦學習的代碼,同時加深了大家對聯邦學習和Pytorch框架的理解。
這里想簡單介紹一下聯邦學習。
聯邦學習說白了,就是假如有 N N N個數據擁有者 F 1 , . . . , F N {F_1,...,F_N} F1?,...,FN?,他們希望使用這些數據來訓練機器學習模型,但是又各自想隱藏自己的數據不被別人所知道(隱私保護),這個過程每個用戶傳遞本地模型參數到中心服務器訓練模型 M F E D M_{FED} MFED?,該過程中任何數據擁有者 F i F_i Fi?都不會暴露其數據 D i D_i Di?給其他人。而傳統的方法將所有的數據放到一起(中心服務器)并使用 D = D 1 ∪ D 2 . . . D N D=D_1 \cup D_2...D_N D=D1?∪D2?...DN?訓練模型 M S U M M_{SUM} MSUM?,但是在傳統方法的過程中,中心服務器會得知所有用戶的數據,故有了聯邦學習這個概念,并由此衍生出了針對聯邦學習的攻擊與防御等。
在這里,我們對比模型 M F E D M_{FED} MFED?和模型 M S U M M_{SUM} MSUM?的精度 V F E D V_{FED} VFED?和 V S U M V_{SUM} VSUM?應該非常接近,如果其精度有了損失,可能會因隱私保護而得不償失了,下面的 δ \delta δ是聯邦學習算法的精度值損失 ∣ V F E D ? V S U M ∣ < δ |V_{FED}-V_{SUM}|<\delta ∣VFED??VSUM?∣<δ
整體架構(main函數)
首先,我們先從整體進行大覽整體邏輯。首先初始化全局模型,然后劃分每個用戶的本地數據集,開始訓練,由每個客戶端進行本地訓練,然后將參數傳遞給中心服務器,進行全局平均更新模型參數并將將新的參數傳遞給每個客戶端。迭代數輪,最終就訓練好了模型
if __name__ == '__main__':# parse argsargs = args_parser() # 參數解析args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') # 切換設備# load dataset and split usersif args.dataset == 'mnist': # 加載數據集trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True, transform=trans_mnist)dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True, transform=trans_mnist)# sample usersif args.iid: # 是否服從獨立同分布的劃分數據集dict_users = mnist_iid(dataset_train, args.num_users)else:dict_users = mnist_noniid(dataset_train, args.num_users)elif args.dataset == 'cifar': # 加載數據集trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])dataset_train = datasets.CIFAR10('./data/cifar', train=True, download=True, transform=trans_cifar)dataset_test = datasets.CIFAR10('./data/cifar', train=False, download=True, transform=trans_cifar)if args.iid:dict_users = cifar_iid(dataset_train, args.num_users)else:exit('Error: only consider IID setting in CIFAR10')else:exit('Error: unrecognized dataset')img_size = dataset_train[0][0].shape# build modelif args.model == 'cnn' and args.dataset == 'cifar': # 選定模型net_glob = CNNCifar(args=args).to(args.device)elif args.model == 'cnn' and args.dataset == 'mnist':net_glob = CNNMnist(args=args).to(args.device)elif args.model == 'mlp':len_in = 1for x in img_size:len_in *= xnet_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)else:exit('Error: unrecognized model')print(net_glob)net_glob.train()# copy weightsw_glob = net_glob.state_dict() # 獲取全局權重# trainingloss_train = []cv_loss, cv_acc = [], []val_loss_pre, counter = 0, 0net_best = Nonebest_loss = Noneval_acc_list, net_list = [], []if args.all_clients: print("Aggregation over all clients") # 所有客戶端的聚合w_locals = [w_glob for i in range(args.num_users)] # 將初始化權重分配給每個用戶for iter in range(args.epochs): # 總的訓練輪次loss_locals = [] #if not args.all_clients:w_locals = []''' args.frac每次梯度下降的比例 args.num_users客戶端數量 '''m = max(int(args.frac * args.num_users), 1) # 選擇需要進行梯度下降的用戶數量idxs_users = np.random.choice(range(args.num_users), m, replace=False) # 隨機選擇for idx in idxs_users:local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device)) # 獲取全局模型并開始本地訓練if args.all_clients:w_locals[idx] = copy.deepcopy(w) # 獲取每個客戶端的本地參數else:w_locals.append(copy.deepcopy(w))loss_locals.append(copy.deepcopy(loss)) # 獲取損失函數# update global weightsw_glob = FedAvg(w_locals) # 進行聚合平均# copy weight to net_globnet_glob.load_state_dict(w_glob) # 更新參數# print lossloss_avg = sum(loss_locals) / len(loss_locals)print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))loss_train.append(loss_avg)# plot loss curveplt.figure()plt.plot(range(len(loss_train)), loss_train)plt.ylabel('train_loss')plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))# testingnet_glob.eval()acc_train, loss_train = test_img(net_glob, dataset_train, args)acc_test, loss_test = test_img(net_glob, dataset_test, args)print("Training accuracy: {:.2f}".format(acc_train))print("Testing accuracy: {:.2f}".format(acc_test))
參數處理
這里給出我們所使用到的參數,還有一些參數在代碼中并沒有使用到
參數 | 解釋 |
---|---|
epochs | 中心服務器訓練的輪次 |
num_users | 客戶端數量 |
frac | 每次進行梯度下降的比例 |
local_ep | 本地訓練模型的輪次 |
local_bs | 本地批量大小 |
lr | 學習率 |
momentum | SGD梯度下降法的動量大小 |
model | 選用模型 |
dataset | 所用數據集 |
iid | 數據集劃分是否符合獨立同分布 |
num_classes | 模型的通道數 |
gpu | 選用模型 |
stopping_rounds | 選用模型 |
verbose | 詳細打印 |
seed | 隨機種子 |
all_clients | 聚合所有的客戶端 |
def args_parser():parser = argparse.ArgumentParser()# federated argumentsparser.add_argument('--epochs', type=int, default=10, help="rounds of training")parser.add_argument('--num_users', type=int, default=100, help="number of users: K")parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")parser.add_argument('--bs', type=int, default=128, help="test batch size")parser.add_argument('--lr', type=float, default=0.01, help="learning rate")parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")# model argumentsparser.add_argument('--model', type=str, default='mlp', help='model name')parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')parser.add_argument('--kernel_sizes', type=str, default='3,4,5',help='comma-separated kernel size to use for convolution')parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")parser.add_argument('--max_pool', type=str, default='True',help="Whether use max pooling rather than strided convolutions")# other argumentsparser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")parser.add_argument('--iid', action='store_true', help='whether i.i.d or not')parser.add_argument('--num_classes', type=int, default=10, help="number of classes")parser.add_argument('--num_channels', type=int, default=3, help="number of channels of imges")parser.add_argument('--gpu', type=int, default=-1, help="GPU ID, -1 for CPU")parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')parser.add_argument('--verbose', action='store_true', help='verbose print')parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')args = parser.parse_args()return args
獨立同分布劃分數據集
可以設置獨立同分布還是非獨立同分布劃分數據集(其實這塊內容基本使用不到),在目前學術界大都采用Dirichlet分布或者Pathological分布。
其實這塊內容從某個方面來看,在常規階段提點無法有較大提升,所以提出了新的場景(Dirichlet分布和Pathological分布),在新的場景中可以完成顯著的提點,當然這個場景本身是沒有問題的,以及衍生的算法也沒有問題。但是我們可以學習到一個學術思路,將我們創新出的算法放置到一個新場景或許有意料之外的效果。
其實也有點先射箭后畫靶的意思了。好了,這里不再過多談論了
def mnist_iid(dataset, num_users): # 獨立同分布劃分"""Sample I.I.D. client data from MNIST dataset:param dataset::param num_users::return: dict of image index"""num_items = int(len(dataset)/num_users)dict_users, all_idxs = {}, [i for i in range(len(dataset))]for i in range(num_users): # 遍歷每個用戶dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False)) # 進行抽取all_idxs = list(set(all_idxs) - dict_users[i]) # 刪除已經抽取過的數據return dict_users # 返回劃分好的字典def mnist_noniid(dataset, num_users): # 非獨立同分布劃分"""Sample non-I.I.D client data from MNIST dataset:param dataset::param num_users::return:"""num_shards, num_imgs = 200, 300idx_shard = [i for i in range(num_shards)]dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}idxs = np.arange(num_shards*num_imgs)labels = dataset.train_labels.numpy()# sort labelsidxs_labels = np.vstack((idxs, labels))idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]idxs = idxs_labels[0,:]# divide and assignfor i in range(num_users):rand_set = set(np.random.choice(idx_shard, 2, replace=False))idx_shard = list(set(idx_shard) - rand_set)for rand in rand_set:dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)return dict_users
模型
MLP模型
class MLP(nn.Module):def __init__(self, dim_in, dim_hidden, dim_out):super(MLP, self).__init__()self.layer_input = nn.Linear(dim_in, dim_hidden)self.relu = nn.ReLU()self.dropout = nn.Dropout()self.layer_hidden = nn.Linear(dim_hidden, dim_out)def forward(self, x):x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1])x = self.layer_input(x)x = self.dropout(x)x = self.relu(x)x = self.layer_hidden(x)return x
卷積模型
class CNNMnist(nn.Module):def __init__(self, args):super(CNNMnist, self).__init__()self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, args.num_classes)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return xclass CNNCifar(nn.Module):def __init__(self, args):super(CNNCifar, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, args.num_classes)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x
本地模型訓練
每個客戶端經過本地訓練,上傳模型置中央服務器上,服務器進行聚合并將聚合后的模型下發到各個客戶端。迭代數次,一個泛化性強大的模型便訓練好了。
這里談到了泛化性,我便多說一點,與之對應的便是個性化。此時引出來了個性化聯邦學習PFL,其實所謂的PFL從技術層面上看就是取巧了,更多的強調的是個性化,那么它是否喪失了泛化性呢?說實話,其實大部分這方面論文早已喪失了泛化性,本身就是本地訓練個模型,但是將其中的某些層經過聚合,實際上不經過聚合,模型的點數也很高。其實這里已經給出了為什么PFL的點數如此之高
當然,是否存在兩者兼具的算法呢?當讓存在,只不過其它方面又存在一些問題。學術其實就是這樣,不斷打補丁
class LocalUpdate(object):def __init__(self, args, dataset=None, idxs=None):self.args = argsself.loss_func = nn.CrossEntropyLoss()self.selected_clients = []self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)def train(self, net):net.train() # 設置為訓練模式# train and updateoptimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)epoch_loss = []for iter in range(self.args.local_ep): # 本地訓練的輪次batch_loss = []for batch_idx, (images, labels) in enumerate(self.ldr_train):images, labels = images.to(self.args.device), labels.to(self.args.device)net.zero_grad() # 梯度清零log_probs = net(images) # 預測loss = self.loss_func(log_probs, labels) # 計算損失函數loss.backward() # 反向傳播optimizer.step() # 進行優化if self.args.verbose and batch_idx % 10 == 0: # 詳細打印程度print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(iter, batch_idx * len(images), len(self.ldr_train.dataset),100. * batch_idx / len(self.ldr_train), loss.item()))batch_loss.append(loss.item())epoch_loss.append(sum(batch_loss)/len(batch_loss))return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
聚合權重
中心服務器每輪接受到了客戶端傳遞的參數,進行平均聚合(同樣的,也可以采用加權聚合),然后再下發給每個客戶端
def FedAvg(w):w_avg = copy.deepcopy(w[0]) # 對第一個客戶端進行深層拷貝for k in w_avg.keys(): # 遍歷每一個參數for i in range(1, len(w)): # 遍歷每一個客戶端并相加w_avg[k] += w[i][k]w_avg[k] = torch.div(w_avg[k], len(w)) # 最后求平均return w_avg
個性化聯邦學習(Personalized federated learning, PFL)
談論到了聯邦學習(Federated Learning),那就不得不談論PFL了,代碼引用于https://github.com/TsingZ0/PFLlib
上述的PFLib
庫不僅僅有PFL,同樣也集成了FL。跑聯邦的實驗很好用,建議使用此框架完成對比實驗,這里就不再詳細介紹代碼了,PFLib
作者本身做的還是很不錯的,代碼架構清晰明了。
看別人的代碼清晰明了,看自己的代碼不堪入目ε(┬┬﹏┬┬)3
實際上,PFL是個偽需求(未來有可能成為真是需求)。我們回顧一下FL的誕生,FL誕生之前是分布式學習,搖身一變成為了聯邦學習。再過幾年,個性化聯邦學習應運而生。FL是為了隱私保護,保護各個客戶端的數據不被其他人獲取,但是依然希望獲取一個泛化性能較強的模型,而PFL為了追求個性化,其實是在泛化性和個性化作了一個平衡,但是隨時間各個論文為了提點,不得不更偏向于個性化
過度個性化 ≈ \approx ≈ 本地訓練模型
從這里就可以看到,目前近幾年部分論文在底層上,實際與本地訓練個模型別無二致,無非添加了些概念(截至2025.6)
談論到這里,其實還一種學術思路,便是將其他領域的方法嫁接到本領域當中,比如對抗學習、元學習、對比學習、知識蒸餾等等(這些都已有論文了)。所以往往不要在本領域尋找創新點,創新點可能在其它領域中等待著發現。