Introduction to FLOWER.

Ankita Sinha
Published in Nerd For Tech
4 min read · May 31, 2021


In my earlier post, I covered what Federated Learning is at a broad level. Here, I will walk you through how to set up your own Federated Learning based model using a framework called Flower.

We will look at a cross-device and asynchronous design. This is very similar to Gboard and Siri, where a local model resides on the edge device (in this case, your phone or Mac).

Federated Learning

Let’s jump right into the components needed to build your own model. A federated learning system needs two parts:

  1. Server
  2. Client

The data scientist has full control over the server. The server hosts the aggregation logic and makes sure all the devices have the latest model parameters.

The clients (devices) have a local model running on the local data.

In our use case, we will follow the steps below.

  1. We will build a simple PyTorch-based neural network model to read images and classify them.
  2. We will first train the model on the local data on the client. Let's start with 3 devices, so we have 3 locally running models on 3 separate devices.
  3. Once a model is trained and we have its parameters, the client tries to connect to the server.
  4. The server then either accepts or rejects the connection request based on some policy. Here we will simply use a first-come, first-served policy.
  5. If the connection goes through, the client sends the model parameters to the server.
  6. The server waits for all 3 sets of model parameters and then aggregates them, thus making use of the data behind all the models.
  7. This can be repeated for as many rounds as we want to train.
  8. Then the server sends the updated weight parameters back to the clients.
  9. The clients will now use the weights for image classification.
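The aggregation in step 6 can be sketched as a weighted average of the clients' parameters, which is the idea behind federated averaging. The following is a minimal illustration in plain Python, not Flower's own implementation: the function name federated_average and the flat-list parameter layout are assumptions made for this example (in practice each client sends a list of NumPy arrays).

```python
from typing import List, Tuple


def federated_average(results: List[Tuple[List[float], int]]) -> List[float]:
    """Average client parameters, weighting each client by its example count.

    `results` holds one (parameters, num_examples) pair per client.
    Hypothetical helper for illustration; not part of the Flower API.
    """
    total_examples = sum(num_examples for _, num_examples in results)
    num_params = len(results[0][0])
    averaged = [0.0] * num_params
    for parameters, num_examples in results:
        # Clients with more local data contribute more to the global model.
        weight = num_examples / total_examples
        for i, value in enumerate(parameters):
            averaged[i] += weight * value
    return averaged


# Three clients, each with two model parameters and a different amount of data.
client_results = [
    ([1.0, 2.0], 100),
    ([3.0, 4.0], 100),
    ([5.0, 6.0], 200),
]
global_parameters = federated_average(client_results)
```

Here the third client holds half of all the examples, so its parameters pull the average towards its own values. The raw data never leaves the clients; only the parameters are shared.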

Let's create a file called server.py and add the following lines:

import flwr as fl

# Start Flower server for three rounds of federated learning
if __name__ == "__main__":
    # FedAvg aggregates the clients' weights by averaging them
    strategy = fl.server.strategy.FedAvg(
        fraction_fit=0.1,
        min_available_clients=3,
    )
    fl.server.start_server("[::]:8080", config={"num_rounds": 3}, strategy=strategy)

This is all we need to start a server listening on port 8080 (the address [::] binds to all network interfaces, not just localhost). The strategy is our policy. num_rounds specifies that training will continue for 3 rounds. Each round can have a different set of clients, based on which devices connected first. fraction_fit=0.1 samples 10% of the available clients in each round. min_available_clients is the minimum number of clients that must be connected before training starts. You can find the various ways you can define your policy here.
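To see how fraction_fit and a minimum client count might interact, here is a small sketch of a sampling rule. It assumes the strategy takes the larger of the fraction-based count and a floor. The helper name clients_sampled_per_round and the min_fit_clients value of 2 are assumptions for illustration, so check the Flower strategy documentation for the actual behaviour.

```python
def clients_sampled_per_round(
    num_available: int,
    fraction_fit: float,
    min_fit_clients: int = 2,  # assumed floor, for illustration only
) -> int:
    """Return how many clients would be asked to train in one round.

    Hypothetical helper; it sketches one plausible sampling rule and is
    not taken from the Flower source code.
    """
    fraction_based = int(num_available * fraction_fit)
    return max(fraction_based, min_fit_clients)


small_deployment = clients_sampled_per_round(3, 0.1)
large_deployment = clients_sampled_per_round(100, 0.1)
```

Under this rule, 10% of 3 connected clients rounds down to zero, so the floor decides how many clients train in a round. With 100 connected clients the fraction takes over.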

You can host server.py on AWS (EC2 or SageMaker), or run it on your own workstation.

Now let's write our client. You can find the Colab notebook here and the code on Git here. Flower (flwr) is based on gRPC, which is not available in the free version of Colab. You can create a Docker container or run it on your workstation instead.

The only difference between a non-FL ANN and an FL-based one lies in connecting to the server and getting the updated weights. We will look at the Federated Learning section of the notebook here.

from collections import OrderedDict
import flwr as fl
import torch

# net, trainloader, testloader, train() and test() are defined earlier in the notebook
class CifarClient(fl.client.NumPyClient):
    def get_parameters(self):
        return [val.cpu().numpy() for _, val in net.state_dict().items()]
    def set_parameters(self, parameters):
        params_dict = zip(net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        net.load_state_dict(state_dict, strict=True)
    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(net, trainloader, epochs=1)
        return self.get_parameters(), len(trainloader), {}
    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)
        return float(loss), len(testloader), {"accuracy": accuracy}

This is the most important class: it implements the Flower client interface. Its methods are:

  1. get_parameters: returns the model parameters to the server as a list of NumPy ndarrays.
  2. set_parameters: sets the model parameters on the client from a list of NumPy ndarrays.
  3. fit: sets the model parameters, trains the model on the client and returns the updated model parameters to the server.
  4. evaluate: sets the model parameters, evaluates the model on the client's local test dataset and returns the result to the server.
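To make the contract of these four methods concrete without needing Flower or PyTorch, here is a toy stand-in. It is only a sketch: the ToyClient class, the one-parameter model y = w * x and the learning rate of 0.05 are all assumptions for illustration, and the class does not subclass fl.client.NumPyClient.

```python
from typing import Dict, List, Tuple


class ToyClient:
    """Stand-in client for the model y = w * x (hypothetical, no Flower dependency)."""

    def __init__(self, data: List[Tuple[float, float]]) -> None:
        self.data = data  # local (x, y) pairs that never leave the client
        self.w = 0.0      # the single model parameter

    def get_parameters(self) -> List[float]:
        return [self.w]

    def set_parameters(self, parameters: List[float]) -> None:
        self.w = parameters[0]

    def fit(self, parameters: List[float], config: Dict) -> Tuple[List[float], int, Dict]:
        # Start from the parameters sent by the server, then train locally.
        self.set_parameters(parameters)
        for x, y in self.data:
            gradient = 2 * (self.w * x - y) * x  # derivative of the squared error
            self.w -= 0.05 * gradient
        return self.get_parameters(), len(self.data), {}

    def evaluate(self, parameters: List[float], config: Dict) -> Tuple[float, int, Dict]:
        self.set_parameters(parameters)
        loss = sum((self.w * x - y) ** 2 for x, y in self.data) / len(self.data)
        return float(loss), len(self.data), {}


# Simulate what a server does with one client: evaluate, train, evaluate again.
client = ToyClient(data=[(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)])  # generated from y = 2x
loss_before, _, _ = client.evaluate([0.0], {})
new_parameters, num_examples, _ = client.fit([0.0], {})
loss_after, _, _ = client.evaluate(new_parameters, {})
```

The server only ever sees the return values of fit and evaluate: a list of parameters, an example count it can use for weighting, and an optional metrics dictionary.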

Now you can run the last cell and check how the accuracy of your model was affected by training through the central server. (Make sure to run 3 copies of client.py, since we set min_available_clients=3.)

References:

https://github.com/adap/flower
