Private Record Linkage (PRL) sessions

Private record linkage sessions find the intersection of datasets and align them in preparation for vertical federated learning.

In a vertical federated learning (VFL) process, two or more parties collaboratively train a model using datasets that hold different features for an overlapping set of records. The overlapping records define the intersection of the datasets. Private record linkage (PRL) uses this intersection to align the datasets so that a shared model can be trained.

Overlapping records are determined privately through a PRL session, which combines Private Set Intersection with Private Record Alignment.

For example, in data sharing between a hospital (party B, the Active party) and a medical imaging centre (party A, the Passive party), only a subset of the hospital's patients will appear in the imaging centre's data. The hospital can run a PRL session to determine the target subset for model training.
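
The toy example below uses hypothetical data (not part of the SDK) to illustrate this vertical partitioning: both parties hold records for a shared id space but with different features, and only the overlapping ids can contribute to joint training.

import pandas as pd

# Hypothetical vertically partitioned data: a shared "id" space, different features.
hospital = pd.DataFrame({"id": [101, 102, 103], "diagnosis": [1, 0, 1]})        # party B (Active)
imaging = pd.DataFrame({"id": [102, 103, 104], "scan_score": [0.7, 0.1, 0.5]})  # party A (Passive)

# PRL identifies and aligns this overlap privately; it is computed in the
# clear here for illustration only.
overlap_ids = sorted(set(hospital["id"]) & set(imaging["id"]))
print(overlap_ids)  # [102, 103]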

PRL Session Overview

In PRL, two parties submit paths to their datasets so that they can be aligned to perform a machine learning task.

  1. ID columns (id_columns) are used to produce a hash that is sent to the server for comparison. The secret for this hash is shared between the clients; the server has no knowledge of it. This comparison is the Private Set Intersection (PSI) part of PRL, illustrated in the sketch after this list.

  2. Once the hashes are compared, the server orchestrates the data alignment, because it knows which indices of each dataset are in common. This is the Private Record Alignment (PRA) part of PRL.
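
The following is a minimal conceptual sketch of the PSI step (not integrate.ai's actual protocol): each client hashes its ID values with a secret key known only to the clients, so the server can intersect the hash sets without learning the raw identifiers or the key.

import hashlib
import hmac

def hash_ids(ids, shared_secret: bytes) -> dict:
    """Keyed hash of each identifier; the server never learns the key."""
    return {
        hmac.new(shared_secret, str(i).encode(), hashlib.sha256).hexdigest(): i
        for i in ids
    }

secret = b"known to the clients only"
passive_hashes = hash_ids([102, 103, 104], secret)  # imaging centre IDs
active_hashes = hash_ids([101, 102, 103], secret)   # hospital IDs

# The server sees only the hashes; intersecting them reveals which records
# match without exposing the raw identifiers (PSI). The matching positions
# it reports back to the clients drive the alignment step (PRA).
overlap = passive_hashes.keys() & active_hashes.keys()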

For more information about privacy when performing PRL, see PRL Privacy for VFL.

PRL Session Example

Use the integrateai_fargate_batch_client_vfl.ipynb notebook to follow along and test the examples shown below by filling in your own variables as required.

This example uses AWS Fargate and Batch to run the session using data in S3 buckets.

  1. Complete the Environment Setup.

  2. Ensure that you have the correct roles and policies for Fargate and Batch. See Using AWS Batch with integrate.ai and Running a training server on AWS Fargate for details.

  3. Authenticate to the integrate.ai API client.

  4. Create a configuration for the PRL session.

    1. Specify a prl_data_config that indicates the columns to use as identifiers when linking the datasets to each other. The number of items in the config determines the number of expected clients. In the example below, two clients submit data, and their datasets are linked by the "id" column.

prl_data_config = {
    "clients": {
        "passive_client": {"id_columns": ["id"]},
        "active_client": {"id_columns": ["id"],},
    }
}

Optional Parameters

The prl_data_config accepts two optional parameters:

  • match_threshold - The minimum similarity score required for two records to be considered a match. The value must be between 0 and 1.0 inclusive.

  • similarity_function - The similarity function to use. Options are dice or hamming.
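
For example, a configuration using both optional parameters might look like the sketch below. The top-level placement of the parameters is an assumption based on the description above; confirm the exact schema against the SDK reference.

prl_data_config = {
    "clients": {
        "passive_client": {"id_columns": ["id"]},
        "active_client": {"id_columns": ["id"]},
    },
    "match_threshold": 0.8,         # hypothetical value: require at least 0.8 similarity
    "similarity_function": "dice",  # or "hamming"
}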

Create the session

To create the session, pass the configuration that contains the client names and identifier columns as the data_config argument. In this example, that is the prl_data_config defined above.

These client names are referenced when running compute on the PRL session and in any downstream sessions that use it.

prl_session = client.create_prl_session(
    name="Testing notebook - PRL",
    description="I am testing PRL session creation through a notebook",
    data_config=prl_data_config
).start()

prl_session.id

Specify AWS parameters and credentials

Specify the paths to the datasets and the AWS Batch job information.

The train and test files can be in either .csv or .parquet format.

# Example data paths in S3
active_train_path = 's3://<path to dataset>/active_train.csv'
passive_train_path = 's3://<path to dataset>/passive_train.csv'
active_test_path = 's3://<path to dataset>/active_test.csv'
passive_test_path = 's3://<path to dataset>/passive_test.csv'

# Specify the AWS parameters
cluster = "iai-server-ecs-cluster"
task_definition = "iai-server-fargate-job"
model_storage = "s3://<path to storage>"
security_group = "iai_server_security_group"
subnet_id = "<subnet>" # Public subnet (routed via IGW)
job_queue = "iai-client-batch-job-queue"
job_def = "iai-client-batch-job"

Specify your AWS credentials if you are generating temporary ones. Otherwise, use the default profile credentials.

# Set your AWS Credentials if you are generating temporary ones, else use the default profile credentials
aws_creds = {
    'ACCESS_KEY': os.environ.get("AWS_ACCESS_KEY_ID"),
    'SECRET_KEY': os.environ.get("AWS_SECRET_ACCESS_KEY"),
    'REGION': os.environ.get("AWS_REGION"),
}

Set up the task builder and task group

Import the taskbuilder and taskgroup from the SDK.

from integrate_ai_sdk.taskgroup.taskbuilder import aws as taskbuilder_aws
from integrate_ai_sdk.taskgroup.base import SessionTaskGroup

Specify the server and batch information to create the task builder objects.

task_server = taskbuilder_aws.fargate(
    cluster=cluster,
    task_definition=task_definition)

tb = taskbuilder_aws.batch( 
    job_queue=job_queue,
    aws_credentials=aws_creds,
    cpu_job_definition=job_def)

Create a task in the task group for each client. The number of tasks in the task group must match the number of clients specified in the data_config used to create the session.

The train_path, test_path, and client_name must be set for each task. The client_name must match a name specified in the data_config used to create the session.

The vcpus and memory parameters are optional overrides for the job definition.

task_group_context = SessionTaskGroup(prl_session)\
    .add_task(task_server.fls(subnet_id, security_group, storage_path=model_storage, client=client))\
    .add_task(tb.prl(train_path=passive_train_path, test_path=passive_test_path, vcpus='2', memory='16384', client=client, client_name="passive_client"))\
    .add_task(tb.prl(train_path=active_train_path, test_path=active_test_path, vcpus='2', memory='16384', client=client, client_name="active_client"))\
    .start()
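
Calling .start() submits every task in the group and returns a task group context, which the next section uses to monitor the submitted jobs.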

Monitor submitted jobs

Each task in the task group kicks off a job in AWS Batch. You can monitor the jobs through the console or the SDK.

The following code returns the session ID that is included in the job name.

# session available in group context after submission
print(task_group_context.session.id)

Next, check the status of the tasks.

# status of tasks submitted
task_group_status = task_group_context.status()
for task_status in task_group_status:
    print(task_status)

Submitted tasks are in the pending state until the clients join and the session is started. Once started, the status changes to running.

# Use to monitor if a session has completed successfully or has failed
# You can modify the time to wait as per your specific task
task_group_context.wait(300)

View the overlap statistics

When the session is complete, you can view the overlap statistics for the datasets.

prl_session.metrics().as_dict()

Example result:

{'session_id': '07d0f8358d',
 'federated_metrics': [],
 'client_metrics': {'passive_client': {'train': {'n_records': 14400,
    'n_overlapped_records': 12963,
    'frac_overlapped': 0.9},
   'test': {'n_records': 3600,
    'n_overlapped_records': 3245,
    'frac_overlapped': 0.9}},
  'active_client': {'train': {'n_records': 14400,
    'n_overlapped_records': 12963,
    'frac_overlapped': 0.9},
   'test': {'n_records': 3600,
    'n_overlapped_records': 3245,
    'frac_overlapped': 0.9}}}}
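
The overlap fractions can be read directly from this dictionary; the short sketch below assumes only the structure shown in the example result.

# Pull the per-client overlap fractions out of the metrics dict shown above.
metrics = prl_session.metrics().as_dict()
for client_name, client_metrics in metrics["client_metrics"].items():
    for split in ("train", "test"):
        print(client_name, split, client_metrics[split]["frac_overlapped"])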

To run a VFL training session on the linked datasets, see VFL FFNet Model Training.

To perform exploratory data analysis on the intersection, see EDA Intersect.
