data-engineering 14 min read 22 April 2022

Building Production ML Pipelines with TFX

A ground-up walkthrough of TensorFlow Extended — orchestrators, metadata, standard components (ExampleGen through Pusher), and building custom components. Written from hands-on work building ML pipelines in 2022.

Context note: This was written in April 2022 while building ML pipelines using TFX. The framework has evolved since then, and the broader industry has moved toward more cloud-native orchestration (Vertex AI Pipelines, SageMaker Pipelines). The patterns here — metadata tracking, training-serving skew prevention, model validation before push — remain relevant. The specific TFX APIs may differ in current versions.


TFX (TensorFlow Extended) is Google’s production ML platform built on top of TensorFlow. It gives you a configurable pipeline framework that handles the entire lifecycle: data ingestion, validation, feature engineering, training, evaluation, and serving. The goal is repeatable, auditable ML pipelines that can be orchestrated across different compute backends.

I worked with TFX while building scalable feature pipelines and model training infrastructure. Here’s what I learned.

Why TFX Exists

The core problem TFX solves: most ML teams have research code that works in notebooks and production code that’s a different system entirely. Features computed differently at training vs. serving time cause silent errors — the model was trained on normalized data but gets raw data in production. Debugging this is painful.

TFX addresses this by building transformations into the TensorFlow graph itself, so the exact same preprocessing runs at training and inference time. It also tracks metadata (what data was used, which model version was trained, what the evaluation results were), so you can trace any deployed model back to the training run and data that produced it.

Orchestrators

TFX pipelines need an orchestrator to manage task execution, scheduling, and dependency resolution. The three main options are:

Apache Beam — unified batch and stream API that runs locally or on distributed backends (Spark, Dataflow). Best for development and testing, or when you want Dataflow for scale without Kubernetes overhead.

Apache Airflow — the DAG-based scheduler most data teams already have. Good fit if your data platform runs on Airflow. Each TFX component becomes an Airflow task, with task ordering derived from the pipeline's artifact dependencies; MLMD records the artifacts and lineage those tasks produce.

Kubeflow Pipelines — Kubernetes-native, containerized. Best for cloud-scale production deployments where you need GPU scheduling, distributed training, and fine-grained resource control.

The Metadata Store (MLMD)

ML Metadata (MLMD) is the backbone of TFX. Every artifact that every component produces (data statistics, schemas, transformed datasets, trained models, evaluation results) gets recorded in MLMD.

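This gives you lineage: for any trained model, you can look up the examples, statistics, schema, and transform graph it was built from, along with its evaluation results.

In production TFX, MLMD is typically backed by MySQL on persistent storage; locally, it uses a SQLite file.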
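To make lineage concrete, here is a toy in-memory version in plain Python. It is not the MLMD API: MLMD stores artifacts, executions, and the events that link them in a database. The toy only mimics the core idea that each artifact remembers which component run produced it and what that run consumed. That is enough to walk a model back to its raw data. The class names, URIs, and the two-step pipeline are hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class Artifact:
    """A typed pointer to data, e.g. Artifact('Model', '/root/Trainer/3')."""
    type_name: str
    uri: str


@dataclass
class ToyLineageStore:
    """Hypothetical stand-in for MLMD: remembers which run produced each artifact."""
    # output URI -> (component name, list of input artifacts)
    producers: dict = field(default_factory=dict)

    def record_execution(self, component, inputs, outputs):
        # Link every output to the component run and the inputs it consumed.
        for artifact in outputs:
            self.producers[artifact.uri] = (component, list(inputs))

    def trace(self, artifact):
        """Walk upstream from an artifact, returning (component, artifact) pairs."""
        lineage, frontier = [], [artifact]
        while frontier:
            current = frontier.pop()
            if current.uri not in self.producers:
                continue  # a root artifact, such as raw input files
            component, inputs = self.producers[current.uri]
            lineage.append((component, current))
            frontier.extend(inputs)
        return lineage


raw = Artifact("ExternalArtifact", "/data/span-1")
examples = Artifact("Examples", "/root/ExampleGen/1")
model = Artifact("Model", "/root/Trainer/3")

store = ToyLineageStore()
store.record_execution("ExampleGen", [raw], [examples])
store.record_execution("Trainer", [examples], [model])

for component, artifact in store.trace(model):
    print(component, artifact.type_name, artifact.uri)
# Trainer Model /root/Trainer/3
# ExampleGen Examples /root/ExampleGen/1
```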

Understanding the Core Abstractions

Artifacts

Every step's output is an artifact: a typed pointer to data. Examples, ExampleStatistics, Schema, TransformGraph, Model, and ModelBlessing are all standard artifact types. Downstream components declare which artifact types they consume.

Parameters

Pipeline-level configuration known before execution: number of training steps, evaluation thresholds, file paths. Much of it is passed as protocol buffer messages (for example TrainArgs and EvalArgs) rather than hardcoded in component logic.

Components

Each component has: a spec (input/output artifact types + parameters), an executor (the actual logic), and a component interface. Standard components are provided; you can write custom ones following the same interface.

Pipeline

A DAG of component instances. TFX infers execution order from artifact dependencies — you don’t manually specify task ordering.
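The inference rule fits in a few lines of plain Python. Each component declares the channels it consumes and produces. An edge is drawn wherever one component's output feeds another's input, and a topological sort then gives a valid order. This is a sketch using the standard library's graphlib (Python 3.9+), not TFX's scheduler. The component table is a simplified, hypothetical subset of the pipeline, deliberately listed out of order.

```python
from graphlib import TopologicalSorter

# Hypothetical spec: component -> (consumed channels, produced channels).
COMPONENTS = {
    "Trainer": ({"transformed_examples", "transform_graph"}, {"model"}),
    "ExampleGen": (set(), {"examples"}),
    "Transform": ({"examples", "schema"}, {"transformed_examples", "transform_graph"}),
    "SchemaGen": ({"statistics"}, {"schema"}),
    "StatisticsGen": ({"examples"}, {"statistics"}),
}


def infer_order(components):
    """Derive an execution order purely from artifact dependencies."""
    producer_of = {}
    for name, (_, outputs) in components.items():
        for channel in outputs:
            producer_of[channel] = name
    # Each component depends on whichever components produce its inputs.
    graph = {
        name: {producer_of[channel] for channel in inputs}
        for name, (inputs, _) in components.items()
    }
    return list(TopologicalSorter(graph).static_order())


order = infer_order(COMPONENTS)
print(order)  # ['ExampleGen', 'StatisticsGen', 'SchemaGen', 'Transform', 'Trainer']
```

Nothing in the table states the order; it falls out of the artifact wiring.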

Standard Components: The Full Pipeline

ExampleGen — Data Ingestion

The entry point. ExampleGen reads raw data and emits tf.Example records in TFRecord format, which downstream components consume efficiently.

Supports: CSV, TFRecord, Avro, Parquet, BigQuery out of the box.

Key concept: Span/Version/Split. Data is organized into Spans (time-based groupings), Versions (revisions of a span), and Splits (train/eval). This lets pipelines process new daily data consistently:

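from tfx import v1 as tfx

# ExampleGen fills {SPAN} and {VERSION} from the directory names under input_base.
input_config = tfx.proto.Input(splits=[
    tfx.proto.Input.Split(name='train', pattern='span-{SPAN}/ver-{VERSION}/train/*'),
    tfx.proto.Input.Split(name='eval', pattern='span-{SPAN}/ver-{VERSION}/eval/*'),
])
example_gen = tfx.components.CsvExampleGen(input_base='/data', input_config=input_config)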
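To show what the placeholders buy you, the sketch below resolves a split pattern against a list of relative file paths. It returns the newest span, then the newest version within that span, which mirrors how ExampleGen picks up the latest span by default when no explicit range is configured. It is a simplified plain-Python model built on re, not ExampleGen's actual resolution code. It assumes the pattern contains both {SPAN} and {VERSION}, and the file paths are made up.

```python
import re


def resolve_latest(pattern, paths):
    """Return (span, version) for the newest span, then newest version in it."""
    # Turn 'span-{SPAN}/ver-{VERSION}/train/*' into a regex with numeric groups.
    regex = re.escape(pattern).replace(r"\*", "[^/]+")
    regex = regex.replace(r"\{SPAN\}", r"(?P<span>\d+)")
    regex = regex.replace(r"\{VERSION\}", r"(?P<version>\d+)")
    matches = [re.fullmatch(regex, path) for path in paths]
    found = [(int(m["span"]), int(m["version"])) for m in matches if m]
    if not found:
        raise ValueError(f"no files match {pattern!r}")
    return max(found)  # tuple comparison: highest span first, then highest version


files = [
    "span-1/ver-1/train/data.csv",
    "span-2/ver-1/train/data.csv",
    "span-2/ver-3/train/data.csv",
    "span-2/ver-3/eval/data.csv",
]
latest = resolve_latest("span-{SPAN}/ver-{VERSION}/train/*", files)
print(latest)  # (2, 3)
```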

StatisticsGen — Data Profiling

Computes descriptive statistics over each split of the ingested data using Apache Beam, so it scales to large datasets. Output feeds both SchemaGen and ExampleValidator.

stats_gen = StatisticsGen(examples=example_gen.outputs['examples'])

On subsequent pipeline runs, you can pass a manually curated schema back in so statistics are computed relative to declared expectations — not inferred from scratch.
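Conceptually, the output is a per-feature summary. The toy function below computes a handful of such statistics over rows represented as dicts:

- present and missing counts;
- mean, standard deviation, min, and max for numeric features;
- distinct values for string features.

The real statistics are computed by TensorFlow Data Validation on Beam and are far richer. This single-process version and its dict output format are assumptions for illustration.

```python
import statistics


def describe(rows):
    """Per-feature summary statistics over a list of {feature: value} dicts."""
    features = sorted({key for row in rows for key in row})
    stats = {}
    for name in features:
        values = [row.get(name) for row in rows]
        present = [v for v in values if v is not None]
        summary = {"count": len(present), "missing": len(values) - len(present)}
        if present and all(isinstance(v, (int, float)) for v in present):
            # Numeric feature: distribution summaries.
            summary.update(
                mean=statistics.fmean(present),
                std=statistics.pstdev(present),
                min=min(present),
                max=max(present),
            )
        else:
            # Categorical feature: the set of observed values.
            summary["unique"] = sorted(set(map(str, present)))
        stats[name] = summary
    return stats


rows = [
    {"fare": 10.0, "payment_type": "cash"},
    {"fare": 20.0, "payment_type": "card"},
    {"fare": 30.0},
]
stats = describe(rows)
print(stats["fare"]["mean"], stats["payment_type"]["missing"])  # 20.0 1
```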

SchemaGen — Schema Inference

Infers a data schema from statistics: feature types, expected ranges, required fields, vocabulary sizes for categoricals. The auto-generated schema is a starting point; you review and modify it, then import the curated version back via ImportSchemaGen.

schema_gen = SchemaGen(statistics=stats_gen.outputs['statistics'])
# Later runs, after curation:
schema_gen = ImportSchemaGen(schema_file='/path/to/curated_schema.pbtxt')
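Here is a rough model of the inference step. For each feature, it guesses a type, marks the feature required if it is never missing, and records the observed values of string features as an allowed domain. The dict schema below is a hypothetical stand-in for the Schema protobuf that SchemaGen writes; only the INT/FLOAT/BYTES type names are borrowed from it. Naive rules like these are exactly why the generated schema should be reviewed before it is imported back.

```python
def infer_schema(rows):
    """Infer a toy schema: type, required flag, and string domain per feature."""
    schema = {}
    for name in sorted({key for row in rows for key in row}):
        values = [row.get(name) for row in rows]
        present = [v for v in values if v is not None]
        if all(isinstance(v, int) for v in present):
            kind = "INT"
        elif all(isinstance(v, (int, float)) for v in present):
            kind = "FLOAT"
        else:
            kind = "BYTES"
        # Required only if the feature appeared in every row of the sample.
        entry = {"type": kind, "required": len(present) == len(values)}
        if kind == "BYTES":
            entry["domain"] = sorted(set(map(str, present)))
        schema[name] = entry
    return schema


schema = infer_schema([
    {"trip_miles": 1.5, "payment_type": "cash", "passengers": 1},
    {"trip_miles": 3.0, "payment_type": "card", "passengers": 2},
    {"trip_miles": 0.8, "payment_type": "cash"},
])
print(schema["payment_type"])
# {'type': 'BYTES', 'required': True, 'domain': ['card', 'cash']}
```

Here passengers comes out optional only because one sampled row lacks it. A reviewer might decide it should be required, which is the kind of edit the curated schema captures.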

ExampleValidator — Data Quality

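Validates each data batch against the schema. Catches anomalies such as missing required features, type mismatches, and values outside the schema's declared domains.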

validator = ExampleValidator(
    statistics=stats_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema']
)
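Validation is then a comparison of each batch against the schema. The sketch below checks three common anomaly kinds against a small dict schema:

- a required feature that is missing;
- a value of the wrong type;
- a string outside its declared domain.

The schema format and anomaly messages are hypothetical, and TFDV's real checks cover many more cases.

```python
def validate(rows, schema):
    """Return sorted anomaly descriptions for rows checked against a toy schema."""
    anomalies = set()
    type_checks = {
        "INT": lambda v: isinstance(v, int),
        "FLOAT": lambda v: isinstance(v, (int, float)),
        "BYTES": lambda v: isinstance(v, str),
    }
    for row in rows:
        for name, spec in schema.items():
            value = row.get(name)
            if value is None:
                if spec["required"]:
                    anomalies.add(f"{name}: required feature missing")
                continue
            if not type_checks[spec["type"]](value):
                anomalies.add(f"{name}: expected {spec['type']}")
            elif "domain" in spec and value not in spec["domain"]:
                anomalies.add(f"{name}: unexpected value {value!r}")
    return sorted(anomalies)


SCHEMA = {
    "trip_miles": {"type": "FLOAT", "required": True},
    "payment_type": {"type": "BYTES", "required": True, "domain": ["card", "cash"]},
}
batch = [
    {"trip_miles": 2.0, "payment_type": "crypto"},  # value outside the domain
    {"payment_type": "cash"},                       # required feature missing
    {"trip_miles": "far", "payment_type": "card"},  # wrong type
]
anomalies = validate(batch, SCHEMA)
for anomaly in anomalies:
    print(anomaly)
```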

Transform — Feature Engineering

The most important component for production correctness. Transform runs a preprocessing_fn that defines all feature transformations, then exports them as a TensorFlow graph (SavedModel). That same graph is applied at training and serving time, so preprocessing cannot silently diverge between the two.

import tensorflow_transform as tft

def preprocessing_fn(inputs):
    outputs = {}
    # Z-score normalization
    outputs['fare_normalized'] = tft.scale_to_z_score(inputs['fare'])
    # Vocabulary mapping
    outputs['payment_type_index'] = tft.compute_and_apply_vocabulary(
        inputs['payment_type'], top_k=100
    )
    # Bucketization
    outputs['trip_distance_bucket'] = tft.bucketize(
        inputs['trip_distance'], num_buckets=10
    )
    return outputs

transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file='preprocessing.py'
)

Global statistics (like scale_to_z_score’s mean and variance) are computed once over the full training set during the Transform run, then frozen into the SavedModel. At serving time, the model applies those frozen statistics — no re-computation needed.
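The analyze-then-freeze idea can be shown without TensorFlow. In the plain-Python analogy below, the analyze step computes the mean and standard deviation over the training values once. It then returns a function with those numbers baked in. The skew-prone alternative recomputes statistics on whatever batch arrives at serving time, so the same raw fare maps to a different feature value. This illustrates the principle only; it is not how tf.Transform is implemented, and the fares are made up.

```python
import statistics


def analyze_z_score(train_values):
    """Analyze phase: compute mean and std once over the training data."""
    mean = statistics.fmean(train_values)
    std = statistics.pstdev(train_values)

    def transform(value):
        # Always uses the frozen training statistics, never the incoming batch's.
        return (value - mean) / std

    return transform


def naive_z_score(batch):
    """Skew-prone alternative: recompute statistics on whatever batch arrives."""
    mean, std = statistics.fmean(batch), statistics.pstdev(batch)
    return [(v - mean) / std for v in batch]


fare_z = analyze_z_score([10.0, 20.0, 30.0, 40.0])  # training mean is 25
serving_batch = [40.0, 50.0]

print(fare_z(serving_batch[0]))         # ~1.342: 40 is above the training mean
print(naive_z_score(serving_batch)[0])  # -1.0: same input, opposite sign
```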

Trainer — Model Training

Trains the model using a run_fn (or the older Estimator-based trainer_fn) defined in a module file. Takes the transformed examples, the transform graph, and optionally a pre-trained model for warm-starting.

trainer = Trainer(
    module_file='model.py',  # Contains run_fn
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=tfx.proto.TrainArgs(num_steps=10000),
    eval_args=tfx.proto.EvalArgs(num_steps=5000)
)

The run_fn receives a FnArgs object with paths to data, the transform graph, and the serving model directory:

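from tfx import v1 as tfx

def run_fn(fn_args: tfx.components.FnArgs) -> None:
    # input_fn is a hypothetical helper that reads transformed TFRecords into a
    # tf.data.Dataset; build_keras_model is assumed to be defined in model.py.
    train_dataset = input_fn(fn_args.train_files, fn_args.data_accessor, fn_args.transform_output)
    eval_dataset = input_fn(fn_args.eval_files, fn_args.data_accessor, fn_args.transform_output)
    model = build_keras_model()
    model.fit(
        train_dataset,
        validation_data=eval_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_steps=fn_args.eval_steps,
    )
    model.save(fn_args.serving_model_dir)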

Tuner — Hyperparameter Tuning

An optional component that wraps KerasTuner to search the hyperparameter space before committing to a full training run. Once good hyperparameters are found, you can import them with an Importer and skip re-tuning:

hparams_importer = Importer(
    source_uri='path/to/best_hyperparameters.txt',
    artifact_type=HyperParameters
).with_id('import_hparams')

trainer = Trainer(
    ...
    hyperparameters=hparams_importer.outputs['result']
)
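The reuse pattern is easy to see in miniature: search once, persist the winner, and have later runs load the file instead of searching again. The sketch below uses a plain random search and a made-up objective in place of KerasTuner and real training. It stores JSON in a hypothetical file and illustrates the caching idea only, not the Tuner component's file format or API.

```python
import json
import random
import tempfile
from pathlib import Path

# Hypothetical search space; KerasTuner defines these through its own API.
SPACE = {"learning_rate": [1e-3, 1e-2, 1e-1], "hidden_units": [16, 32, 64]}


def toy_objective(hp):
    """Stand-in for 'train briefly, return validation loss' (lower is better)."""
    return abs(hp["learning_rate"] - 1e-2) + abs(hp["hidden_units"] - 32) / 100


def random_search(space, objective, trials, seed=0):
    """Sample `trials` configurations and keep the one with the lowest objective."""
    rng = random.Random(seed)
    candidates = [{k: rng.choice(v) for k, v in space.items()} for _ in range(trials)]
    return min(candidates, key=objective)


def load_or_tune(path, space, objective, trials=20):
    """Reuse saved hyperparameters if the file exists; otherwise search and save."""
    if path.exists():
        return json.loads(path.read_text()), False
    best = random_search(space, objective, trials)
    path.write_text(json.dumps(best))
    return best, True


with tempfile.TemporaryDirectory() as tmp:
    hp_file = Path(tmp) / "best_hyperparameters.json"
    first, tuned_first = load_or_tune(hp_file, SPACE, toy_objective)    # runs the search
    second, tuned_second = load_or_tune(hp_file, SPACE, toy_objective)  # reuses the file
    print(tuned_first, tuned_second, first == second)  # True False True
```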

Evaluator — Model Validation

Runs TensorFlow Model Analysis on the trained model against eval data. Computes sliced metrics (e.g., accuracy broken down by category) and compares against a baseline model. If the new model doesn’t pass the evaluation threshold, it doesn’t get pushed — no bad models reach production silently.
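The gating decision can be modeled as two checks per slice. First, every slice must clear an absolute floor. Second, no slice may fall too far below the baseline model. TFMA expresses these as value and change thresholds in its evaluation config. The function and nested-dict layout below are a hypothetical plain-Python model of that logic, and in this sketch every slice is held to the same thresholds.

```python
def bless(candidate, baseline, metric, lower_bound, max_regression=0.0):
    """Decide whether a candidate model passes sliced evaluation.

    candidate / baseline map slice name -> {metric name: value}; higher is better.
    Returns (blessed, reasons_for_failure).
    """
    failures = []
    for slice_name, metrics in candidate.items():
        value = metrics[metric]
        # Value threshold: an absolute floor every slice must clear.
        if value < lower_bound:
            failures.append(f"{slice_name}: {metric} {value} < {lower_bound}")
        # Change threshold: no slice may fall too far below the baseline model.
        base = baseline.get(slice_name, {}).get(metric)
        if base is not None and value < base - max_regression:
            failures.append(f"{slice_name}: {metric} regressed {base} -> {value}")
    return not failures, failures


candidate = {"overall": {"accuracy": 0.91}, "payment_type=cash": {"accuracy": 0.84}}
baseline = {"overall": {"accuracy": 0.90}, "payment_type=cash": {"accuracy": 0.86}}

blessed, reasons = bless(candidate, baseline, "accuracy", lower_bound=0.80, max_regression=0.01)
print(blessed, reasons)
# False ['payment_type=cash: accuracy regressed 0.86 -> 0.84']
```

The example shows why slicing matters. Overall accuracy improved, yet the model is still blocked because one slice regressed.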

InfraValidator — Serving Infrastructure Check

Validates that the model loads correctly in the actual serving infrastructure before it’s pushed. Catches compatibility issues (TF version mismatches, incompatible SavedModel signatures) that Evaluator won’t catch because Evaluator only tests model quality.

Pusher — Model Deployment

Pushes a validated model to a serving destination (filesystem, TF Serving, Vertex AI). Only pushes if both Evaluator and InfraValidator have blessed the model.

pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    infra_blessing=infra_validator.outputs['blessing'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory='/serving/models'
        )
    )
)
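Here is a minimal sketch of the gate, assuming each blessing has already been reduced to a boolean. If every gate passed, the model is copied into a numbered version directory under the base directory, the layout TensorFlow Serving polls for new model versions. Otherwise nothing happens. The function, its arguments, and the timestamp-based default version are illustrative assumptions, not the Pusher's implementation.

```python
import shutil
import tempfile
import time
from pathlib import Path


def push_if_blessed(model_dir, blessings, base_directory, version=None):
    """Copy model_dir to base_directory/<version> only if every gate blessed it."""
    if not all(blessings.values()):
        return None  # a single unblessed gate blocks the push
    # Numbered version directories are the layout TF Serving watches.
    version = int(time.time()) if version is None else version
    target = Path(base_directory) / str(version)
    shutil.copytree(model_dir, target)
    return target


with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp) / "trainer_output"
    model_dir.mkdir()
    (model_dir / "saved_model.pb").write_bytes(b"")  # placeholder model file
    serving_root = Path(tmp) / "serving" / "models"

    blocked = push_if_blessed(
        model_dir, {"evaluator": True, "infra_validator": False}, serving_root)
    pushed = push_if_blessed(
        model_dir, {"evaluator": True, "infra_validator": True}, serving_root, version=1)
    pushed_versions = sorted(p.name for p in serving_root.iterdir())
    pushed_files = sorted(p.name for p in pushed.iterdir())
    print(blocked, pushed_versions)  # None ['1']
```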

BulkInferrer — Batch Inference

For running predictions on unlabeled data at scale. Takes a trained model and a dataset of examples, produces inference results without standing up a serving endpoint. Useful for offline scoring pipelines.

Custom Components

When standard components don’t cover a data source or processing step, TFX lets you write custom ones following the same artifact interface. The most common case is a custom ExampleGen for proprietary data formats.

Custom ExampleGen components use Apache Beam to parallelize ingestion across arbitrary data sources — databases, APIs, custom binary formats. The output is always tf.Example records, so downstream components work unchanged.
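Downstream components can stay unchanged because every record ends up in the same shape. A tf.train.Example maps feature names to exactly one of a bytes list, a float list, or an int64 list. The encoder below mimics that shape with plain dicts so the mapping rules are visible without TensorFlow. In a real custom ExampleGen, the equivalent step would build tf.train.Example protos inside a Beam transform. The source row and its fields are hypothetical.

```python
def to_feature(value):
    """Map a value (or list of values) to one of the three tf.Example list kinds."""
    values = list(value) if isinstance(value, (list, tuple)) else [value]
    if all(isinstance(v, int) for v in values):
        return {"int64_list": values}
    if all(isinstance(v, (int, float)) for v in values):
        return {"float_list": [float(v) for v in values]}
    if all(isinstance(v, (str, bytes)) for v in values):
        return {"bytes_list": [v.encode("utf-8") if isinstance(v, str) else v for v in values]}
    raise TypeError(f"unsupported feature value: {value!r}")


def to_example_dict(record):
    """Dict mirror of tf.train.Example: {'features': {name: feature}}."""
    return {"features": {name: to_feature(v) for name, v in sorted(record.items())}}


# A row from some proprietary source (hypothetical fields).
row = {"trip_id": 42, "fare": 12.5, "payment_type": "card", "stops": [1, 3]}
example = to_example_dict(row)
print(example["features"]["payment_type"])  # {'bytes_list': [b'card']}
```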

The Full Pipeline DAG

A complete TFX pipeline for a supervised ML problem looks like:

ExampleGen → StatisticsGen → SchemaGen → ExampleValidator
SchemaGen → Transform → Trainer → Evaluator → Pusher
Trainer → InfraValidator → Pusher

(ExampleGen's examples also feed Transform and Evaluator directly.)

ExampleValidator and Transform can run in parallel after SchemaGen: both consume the schema, but neither depends on the other's outputs. The same holds for Evaluator and InfraValidator after Trainer. TFX's DAG runner works this out automatically from artifact dependencies.
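Finding what can run concurrently amounts to repeatedly taking every component whose upstream dependencies have all finished. The sketch below does this with graphlib (Python 3.9+) over the dependencies implied by the DAG above. It assumes Evaluator and InfraValidator also read ExampleGen's examples. It illustrates the scheduling idea and is not TFX's runner.

```python
from graphlib import TopologicalSorter

# Upstream dependencies per component (ExampleGen has none, so it is a root).
DEPS = {
    "StatisticsGen": {"ExampleGen"},
    "SchemaGen": {"StatisticsGen"},
    "ExampleValidator": {"StatisticsGen", "SchemaGen"},
    "Transform": {"ExampleGen", "SchemaGen"},
    "Trainer": {"Transform"},
    "Evaluator": {"ExampleGen", "Trainer"},
    "InfraValidator": {"ExampleGen", "Trainer"},
    "Pusher": {"Trainer", "Evaluator", "InfraValidator"},
}


def execution_waves(deps):
    """Group components into waves; everything in one wave can run concurrently."""
    sorter = TopologicalSorter(deps)
    sorter.prepare()
    waves = []
    while sorter.is_active():
        ready = sorted(sorter.get_ready())  # all dependencies already finished
        waves.append(ready)
        sorter.done(*ready)                 # pretend the whole wave completed
    return waves


waves = execution_waves(DEPS)
for wave in waves:
    print(wave)
```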

What I’d Do Differently Now

TFX’s strengths — metadata tracking, training-serving skew prevention, model validation gates — are the right ideas. The friction is in the TF-centric API surface: everything is TFRecord and SavedModel, so bringing non-TF models in requires adapter work.

In 2024, I’d evaluate Vertex AI Pipelines or Metaflow for the same patterns with less ceremony. But the conceptual architecture TFX established — artifact lineage, validation gates, frozen preprocessing graphs — is still the right way to think about production ML.
