Running a training server on AWS Fargate

Learn how to run your integrate.ai training server on AWS Fargate

To follow along with this tutorial, open the integrateai_fargate_server.ipynb notebook, which you can find in the sample_packages/sample_notebook folder in the SDK package.

Configure AWS Fargate

If this is your first time running a training session with a server in AWS Fargate, ensure that you have followed the instructions for Setting up AWS Fargate before you continue.

Running AWS Batch jobs in Fargate through the SDK

Install the SDK
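If the SDK is not already installed in your notebook environment, install it before continuing. The commands below are a minimal sketch, assuming the integrate-ai CLI package on PyPI and its iai sdk install command; refer to the SDK installation instructions for the exact steps for your environment.

# Install the integrate.ai CLI (assumed package name on PyPI)
!pip install integrate-ai

# Install the SDK through the CLI; see the installation docs for available options
!iai sdk install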

Authenticate to the API client

First, authenticate the client with your IAI token.

import os

# Read the IAI token from the environment; set IAI_TOKEN before running the notebook
IAI_TOKEN = os.environ.get("IAI_TOKEN")

from integrate_ai_sdk.api import connect

# Authenticate to the integrate.ai API client
client = connect(token=IAI_TOKEN)

Model config and data schema

Set up your model configuration and data schema for your training session. For detailed information, see Building a Custom Model. A generic example that matches the sample notebook is provided below.

# Example model config
model_config = {
    "experiment_name": "test_synthetic_tabular",
    "experiment_description": "test_synthetic_tabular",
    "strategy": {"name": "FedAvg", "params": {}},
    "model": {"params": {"input_size": 15, "hidden_layer_sizes": [6, 6, 6], "output_size": 2}},
    "balance_train_datasets": False,
    "ml_task": {
        "type": "classification",
        "params": {
            "loss_weights": None,
        },
    },
    "optimizer": {"name": "SGD", "params": {"learning_rate": 0.2, "momentum": 0.0}},
    "differential_privacy_params": {"epsilon": 4, "max_grad_norm": 7},
    "save_best_model": {
        "metric": "loss",  # to disable this and save model from the last round, set to None
        "mode": "min",
    },
    "seed": 23,  # for reproducibility
}

# Example data schema
data_schema = {
    "predictors": ["x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14"],
    "target": "y",
}

Create a training session

This example session uses 2 clients and 2 rounds. The training_session definition is passed to the server as part of the task definition.

Note: If you are using a custom model, ensure that you specify the correct model_config and data_schema.

training_session = client.create_fl_session(
    name="Fargate Testing notebook",
    description="I am testing session creation with Fargate through a notebook",
    min_num_clients=2,
    num_rounds=2,
    package_name="iai_ffnet",
    model_config=model_config,  
    data_config=data_schema,
).start()

# Print the training session ID
training_session.id 

Specify optional AWS credentials

If you are generating temporary AWS credentials, specify them here. Otherwise, the default profile credentials are used.

aws_creds = {
    'ACCESS_KEY': os.environ.get("AWS_ACCESS_KEY_ID"),
    'SECRET_KEY': os.environ.get("AWS_SECRET_ACCESS_KEY"),
    'SESSION_TOKEN': os.environ.get("AWS_SESSION_TOKEN"),
    'REGION': os.environ.get("AWS_REGION"),
}

Specify the Fargate cluster, task definition, and network parameters

Configure the cluster, task definition, and network parameters on AWS first, then specify them as variables for the SDK.

# Specify the name of your cluster, task definition, and network parameters
fargate_cluster = "<fargate_cluster_name>"
task_def = "<fargate_task_definition>"
subnet_id = "<vpc_network_subnet_id>"
security_group = "<iai_server_security_group>"
ssm_token_key = "<ssm_token_key>"  # SSM parameter that stores your IAI token (see Setting up AWS Fargate)

# Dataset paths, AWS Batch job queue and job definition, and model storage location
train_path1 = '{train path 1}'
train_path2 = '{train path 2}'
test_path = '{test path}'
job_queue = '<fargate-job-queue>'
job_def = '<iai-server-fargate-job>'

model_storage = '{model storage path}'

With the credentials and variables defined, you can now use the SDK to run the training server on AWS Fargate.

Run the training server

The SDK provides taskgroup and taskbuilder objects to simplify the process of creating and managing Fargate and AWS Batch tasks.

Create a Fargate task builder object

from integrate_ai_sdk.taskgroup.taskbuilder import aws as taskbuilder_aws

tb = taskbuilder_aws.fargate(
    aws_credentials=aws_creds,
    cluster=fargate_cluster,
    task_definition=task_def)

Create an AWS Batch task builder object

tb_batch = taskbuilder_aws.batch(
    aws_credentials=aws_creds,
    job_queue=job_queue,
    cpu_job_definition=job_def)

Create a taskgroup

The taskgroup starts the server task and the client batch jobs.

Here we create a session task group that takes as input the training_session created earlier. The first task added (tb) starts the server. The tb_batch task is added twice, once for each client.

Tip: See Create a training session to review the session definition.

from integrate_ai_sdk.taskgroup.base import SessionTaskGroup

task_group_context = (
    SessionTaskGroup(training_session)
    .add_task(tb.fls(subnet_id, security_group, ssm_token_key, storage_path=model_storage))
    .add_task(tb_batch.hfl(train_path=train_path1, test_path=test_path, vcpus='2', memory='16384'))
    .add_task(tb_batch.hfl(train_path=train_path2, test_path=test_path, vcpus='2', memory='16384'))
    .start()
)

You can monitor the running tasks to check training progress. The wait call below blocks until the tasks complete or the specified timeout is reached.

task_group_context.wait(300)
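When the tasks finish, you can inspect the results from the notebook. The following is a minimal sketch that assumes the task group context exposes per-task contexts with a status() method and that the session object exposes metrics(), as in other integrate.ai sample notebooks; check the SDK reference for the helpers available in your version.

import json

# Print the final status reported by each task in the group (server and clients)
for task_context in task_group_context.contexts.values():
    print(json.dumps(task_context.status(), indent=4))

# Retrieve the metrics recorded for the completed training session (assumed helper)
metrics = training_session.metrics().as_dict()
print(metrics)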
