Context note: This was written in April 2022 while building ML pipelines using TFX. The framework has evolved since then, and the broader industry has moved toward more cloud-native orchestration (Vertex AI Pipelines, SageMaker Pipelines). The patterns here — metadata tracking, training-serving skew prevention, model validation before push — remain relevant. The specific TFX APIs may differ in current versions.
TFX (TensorFlow Extended) is Google’s production ML platform built on top of TensorFlow. It gives you a configurable pipeline framework that handles the entire lifecycle: data ingestion, validation, feature engineering, training, evaluation, and serving. The goal is repeatable, auditable ML pipelines that can be orchestrated across different compute backends.
I worked with TFX while building scalable feature pipelines and model training infrastructure. Here’s what I learned.
Why TFX Exists
The core problem TFX solves: most ML teams have research code that works in notebooks and production code that’s a different system entirely. Features computed differently at training vs. serving time cause silent errors — the model was trained on normalized data but gets raw data in production. Debugging this is painful.
TFX addresses this by building transformations into the TensorFlow graph itself, so the exact same preprocessing runs at training and inference time. It also tracks all metadata (what data was used, which model version was trained, what the evaluation results were), so you can trace any deployed model back to its training run.
Orchestrators
TFX pipelines need an orchestrator to manage task execution, scheduling, and dependency resolution. The three most common options are:
Apache Beam: unified batch and stream API that runs locally or on distributed backends (Spark, Dataflow). Best for development and testing, or when you want Dataflow for scale without Kubernetes overhead.
Apache Airflow: the DAG-based scheduler most data teams already have. Good fit if your data platform runs on Airflow. TFX components become Airflow tasks; Airflow enforces the task ordering derived from the pipeline's artifact dependencies, while every artifact is still recorded in MLMD.
Kubeflow Pipelines: Kubernetes-native, containerized. Best for cloud-scale production deployments where you need GPU scheduling, distributed training, and fine-grained resource control.
The Metadata Store (MLMD)
ML Metadata (MLMD) is the backbone of TFX. Every artifact that any component produces (data statistics, schemas, transformed datasets, trained models, evaluation results) is recorded in MLMD, along with the executions that produced and consumed it.
This gives you:
- Lineage: trace any model back to the exact dataset it was trained on
- Caching: skip re-running components if inputs haven’t changed
- Comparison: compare current model metrics against previous runs before deciding to push
In production, MLMD is typically backed by a MySQL database (Kubeflow Pipelines deploys one for this purpose). Locally, it uses a SQLite file.
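The caching idea fits in a few lines of standard-library Python. This is a toy, not MLMD's schema or API: it assumes an execution can be fingerprinted by component name, input artifact URIs, and parameters, and it stores those fingerprints in SQLite (the same backend MLMD uses locally). The class name `ToyExecutionCache` is hypothetical.

```python
import hashlib
import json
import sqlite3


class ToyExecutionCache:
    """Toy execution cache keyed by a fingerprint of a component's inputs.

    Not MLMD: it only illustrates why recording inputs lets a pipeline skip reruns.
    """

    def __init__(self, path: str = ":memory:"):
        self.conn = sqlite3.connect(path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS executions ("
            "fingerprint TEXT PRIMARY KEY, output_uri TEXT)"
        )

    @staticmethod
    def fingerprint(component: str, input_uris: list[str], params: dict) -> str:
        # sort_keys makes logically identical parameter dicts hash identically.
        payload = json.dumps(
            {"component": component, "inputs": sorted(input_uris), "params": params},
            sort_keys=True,
        )
        return hashlib.sha256(payload.encode()).hexdigest()

    def run(self, component, input_uris, params, execute):
        """Return (output_uri, cache_hit); call execute() only on a cache miss."""
        key = self.fingerprint(component, input_uris, params)
        row = self.conn.execute(
            "SELECT output_uri FROM executions WHERE fingerprint = ?", (key,)
        ).fetchone()
        if row is not None:
            return row[0], True  # same inputs and params as a past run: skip it
        output_uri = execute()
        self.conn.execute("INSERT INTO executions VALUES (?, ?)", (key, output_uri))
        self.conn.commit()
        return output_uri, False
```

Changing either an input artifact or a parameter changes the fingerprint and forces a rerun. A real metadata store also records lineage between executions and artifacts, which this sketch leaves out.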
Understanding the Core Abstractions
Artifacts
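Every step's output is an artifact: a typed pointer to data stored at some URI, plus its metadata. Output keys such as examples, statistics, schema, transform_graph, model, and blessing carry artifacts of types like Examples, ExampleStatistics, Schema, TransformGraph, Model, and ModelBlessing. Downstream components declare which artifact types they consume.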
Parameters
Pipeline-level configuration known before execution: number of training steps, eval thresholds, file paths. Many of these are passed as protobuf messages (for example, TrainArgs and EvalArgs) rather than hardcoded in component logic.
Components
Each component has a spec (its input/output artifact types and parameters), an executor (the actual logic), and a component interface (the class you instantiate in a pipeline, which ties spec and executor together). Standard components are provided; you can write custom ones following the same interface.
Pipeline
A DAG of component instances. TFX infers execution order from artifact dependencies — you don’t manually specify task ordering.
Standard Components: The Full Pipeline
ExampleGen — Data Ingestion
The entry point. ExampleGen reads raw data and emits tf.Example records in TFRecord format, which downstream components consume efficiently.
Supports: CSV, TFRecord, Avro, Parquet, BigQuery out of the box.
Key concept: Span/Version/Split. Data is organized into Spans (time-based groupings), Versions (revisions of a span), and Splits (train/eval). This lets pipelines process new daily data consistently:
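```python
from tfx.v1 import proto
from tfx.v1.components import CsvExampleGen

# One glob pattern per split; ExampleGen resolves {SPAN} and {VERSION}.
input_config = proto.Input(splits=[
    proto.Input.Split(name='train', pattern='span-{SPAN}/ver-{VERSION}/train/*'),
    proto.Input.Split(name='eval', pattern='span-{SPAN}/ver-{VERSION}/eval/*'),
])
example_gen = CsvExampleGen(input_base='/data', input_config=input_config)
```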
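With `{SPAN}` and `{VERSION}` in the patterns, ExampleGen by default processes the most recent span and, within it, the latest version. The resolution rule can be sketched as follows; `latest_span_version` is a hypothetical helper that assumes the directory layout above, not ExampleGen's implementation.

```python
import re


def latest_span_version(paths: list[str], split: str = "train") -> tuple[int, int]:
    """Return the highest (span, version) pair among paths for one split.

    Assumes the 'span-{SPAN}/ver-{VERSION}/<split>/...' layout; toy resolver only.
    """
    pattern = re.compile(rf"span-(\d+)/ver-(\d+)/{re.escape(split)}/")
    found = []
    for path in paths:
        match = pattern.search(path)
        if match:
            # Compare numerically so span-10 sorts after span-9.
            found.append((int(match.group(1)), int(match.group(2))))
    if not found:
        raise ValueError(f"no files match split {split!r}")
    # Tuple comparison: highest span first, then highest version within it.
    return max(found)
```

Parsing the numbers as integers matters: compared as strings, `span-9` would sort after `span-10`.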
StatisticsGen — Data Profiling
Computes descriptive statistics over training and serving data using Apache Beam (scales to large datasets). Output feeds both SchemaGen and ExampleValidator.
```python
stats_gen = StatisticsGen(examples=example_gen.outputs['examples'])
```
On subsequent pipeline runs, you can pass a manually curated schema back in so statistics are computed relative to declared expectations — not inferred from scratch.
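StatisticsGen scales because partial statistics computed on different workers can be merged. A minimal sketch of that idea for one numeric feature, assuming values arrive as Python numbers with `None` for missing; `NumericStats` is a hypothetical class, and the real component (built on TensorFlow Data Validation) computes much more, including histograms and quantiles.

```python
import math
from dataclasses import dataclass


@dataclass
class NumericStats:
    """Mergeable partial statistics for one numeric feature (toy)."""
    count: int = 0
    missing: int = 0
    total: float = 0.0
    total_sq: float = 0.0
    minimum: float = math.inf
    maximum: float = -math.inf

    def add(self, value) -> "NumericStats":
        if value is None:
            self.missing += 1
        else:
            self.count += 1
            self.total += value
            self.total_sq += value * value
            self.minimum = min(self.minimum, value)
            self.maximum = max(self.maximum, value)
        return self

    def merge(self, other: "NumericStats") -> "NumericStats":
        # Merge order doesn't matter (up to floating-point rounding), so
        # shards processed on different workers can be combined freely.
        return NumericStats(
            count=self.count + other.count,
            missing=self.missing + other.missing,
            total=self.total + other.total,
            total_sq=self.total_sq + other.total_sq,
            minimum=min(self.minimum, other.minimum),
            maximum=max(self.maximum, other.maximum),
        )

    def mean(self) -> float:
        return self.total / self.count

    def std(self) -> float:
        # Population std from running sums: fine for a sketch, less stable than Welford.
        return math.sqrt(max(self.total_sq / self.count - self.mean() ** 2, 0.0))
```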
SchemaGen — Schema Inference
Infers a data schema from statistics: feature types, presence (required vs. optional), valency, and domains (vocabularies) for categorical string features. The auto-generated schema is a starting point: you review it, add what inference can't know (such as expected numeric ranges), then import the curated version back via ImportSchemaGen.
```python
schema_gen = SchemaGen(statistics=stats_gen.outputs['statistics'])

# Later runs, after curation:
schema_gen = ImportSchemaGen(schema_file='/path/to/curated_schema.pbtxt')
```
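The infer-then-curate loop can be sketched with a plain dict standing in for the schema protobuf. `infer_schema` is a hypothetical helper; it only infers a type, presence, and string domains, and leaves numeric ranges to the curation step.

```python
def infer_schema(rows: list[dict]) -> dict:
    """Toy schema inference: type, presence, and string domain per feature."""
    schema = {}
    features = {key for row in rows for key in row}
    for name in sorted(features):
        values = [row[name] for row in rows if row.get(name) is not None]
        if not values:
            continue  # nothing observed; leave this feature for manual curation
        if all(isinstance(v, (int, float)) for v in values):
            spec = {"type": "float"}  # numeric ranges are added during curation
        else:
            spec = {"type": "string", "domain": sorted({str(v) for v in values})}
        # Required only if every row in the sample had a value.
        spec["required"] = len(values) == len(rows)
        schema[name] = spec
    return schema
```

Curation then means editing this draft, for example adding `"min": 0.0` to `fare` or widening a domain, before it is imported back.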
ExampleValidator — Data Quality
Validates each data batch against the schema. Catches:
- Missing required features
- Out-of-range values
- Training-serving skew (when you provide both training and serving statistics)
- Data drift across consecutive pipeline runs
```python
validator = ExampleValidator(
    statistics=stats_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'],
)
```
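The checks listed above can be sketched against a toy dict schema. The schema keys (`required`, `min`, `max`, `domain`) and all function names are hypothetical. For skew, the sketch uses the L-infinity distance between category frequencies, similar in spirit to the `infinity_norm` skew comparator that TensorFlow Data Validation offers for categorical features.

```python
def validate(batch: list[dict], schema: dict) -> list[str]:
    """Return human-readable anomalies for one batch (toy ExampleValidator)."""
    anomalies = []
    for name, spec in schema.items():
        present = [row[name] for row in batch if row.get(name) is not None]
        if spec.get("required") and len(present) < len(batch):
            anomalies.append(f"{name}: missing required feature")
        if "min" in spec and any(v < spec["min"] for v in present):
            anomalies.append(f"{name}: value below min")
        if "max" in spec and any(v > spec["max"] for v in present):
            anomalies.append(f"{name}: value above max")
        if "domain" in spec and any(str(v) not in spec["domain"] for v in present):
            anomalies.append(f"{name}: value outside domain")
    return anomalies


def linf_distance(train_counts: dict, serving_counts: dict) -> float:
    """Largest absolute gap between two categorical frequency distributions."""
    train_total = sum(train_counts.values())
    serving_total = sum(serving_counts.values())
    keys = set(train_counts) | set(serving_counts)
    return max(
        abs(train_counts.get(k, 0) / train_total - serving_counts.get(k, 0) / serving_total)
        for k in keys
    )


def has_skew(train_counts: dict, serving_counts: dict, threshold: float) -> bool:
    # Flag skew when any single category's share moves by more than `threshold`.
    return linf_distance(train_counts, serving_counts) > threshold
```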
Transform — Feature Engineering
The most important component for production correctness. Transform runs a preprocessing_fn that defines all feature transformations, then exports the result as a TensorFlow graph (a SavedModel). The same graph is applied at training and serving time, so the preprocessing logic cannot drift apart between the two. (Skew caused by the data itself changing is still ExampleValidator's job.)
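```python
import tensorflow_transform as tft


def preprocessing_fn(inputs):
    """Maps raw features to transformed features; traced into a TF graph."""
    outputs = {}
    # Z-score normalization: mean/variance come from the full training set.
    outputs['fare_normalized'] = tft.scale_to_z_score(inputs['fare'])
    # Vocabulary mapping: keep the 100 most frequent values as integer ids.
    outputs['payment_type_index'] = tft.compute_and_apply_vocabulary(
        inputs['payment_type'], top_k=100
    )
    # Bucketization: 10 quantile-based buckets.
    outputs['trip_distance_bucket'] = tft.bucketize(
        inputs['trip_distance'], num_buckets=10
    )
    return outputs
```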
```python
transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file='preprocessing.py',  # defines preprocessing_fn
)
```
Global statistics (like scale_to_z_score’s mean and variance) are computed once over the full training set during the Transform run, then frozen into the SavedModel. At serving time, the model applies those frozen statistics — no re-computation needed.
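The analyze-then-freeze split can be illustrated in plain Python. This is a sketch under simplifying assumptions, not tf.Transform: features arrive as Python dicts, the vocabulary uses exact counts, and bucket boundaries come from a simple quantile rule rather than tf.Transform's approximate quantiles. `analyze` and `transform` are hypothetical names mirroring the three transformations in `preprocessing_fn`.

```python
import bisect
import math
from collections import Counter


def analyze(train_rows: list[dict], top_k: int = 100, num_buckets: int = 10) -> dict:
    """Analyze phase: one full pass over training data to compute constants."""
    fares = [row["fare"] for row in train_rows]
    mean = sum(fares) / len(fares)
    std = math.sqrt(sum((f - mean) ** 2 for f in fares) / len(fares))
    # Most frequent values first, capped at top_k.
    counts = Counter(row["payment_type"] for row in train_rows)
    vocab = {token: i for i, (token, _) in enumerate(counts.most_common(top_k))}
    # Simple quantile cut points: num_buckets - 1 boundaries.
    distances = sorted(row["trip_distance"] for row in train_rows)
    boundaries = [distances[len(distances) * i // num_buckets] for i in range(1, num_buckets)]
    return {"mean": mean, "std": std, "vocab": vocab, "boundaries": boundaries}


def transform(row: dict, frozen: dict) -> dict:
    """Transform phase: uses only the frozen constants, never batch statistics."""
    std = frozen["std"]
    return {
        "fare_normalized": (row["fare"] - frozen["mean"]) / std if std else 0.0,
        # Unseen values map to -1 as an out-of-vocabulary marker.
        "payment_type_index": frozen["vocab"].get(row["payment_type"], -1),
        "trip_distance_bucket": bisect.bisect_right(frozen["boundaries"], row["trip_distance"]),
    }
```

Because `transform` only reads `frozen`, even a single serving request is normalized with training-set statistics, which is the same property Transform provides by embedding those constants in the graph.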
Trainer — Model Training
Trains the model using a run_fn (or the older Estimator-based trainer_fn) defined in a module file. Takes transformed examples, the transform graph, and an optional pre-trained model (the base_model input) for warm-starting.
```python
from tfx.proto import trainer_pb2
from tfx.v1.components import Trainer

trainer = Trainer(
    module_file='model.py',  # Contains run_fn
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000),
)
```
The run_fn receives a FnArgs object with paths to data, the transform graph, and the serving model directory:
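```python
import tensorflow_transform as tft
from tfx import v1 as tfx
from tfx_bsl.public import tfxio


def run_fn(fn_args: tfx.components.FnArgs) -> None:
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
    schema = tf_transform_output.transformed_metadata.schema
    # 'label' is a placeholder for your transformed label feature's name.
    options = tfxio.TensorFlowDatasetOptions(batch_size=64, label_key='label')
    train_dataset = fn_args.data_accessor.tf_dataset_factory(fn_args.train_files, options, schema)
    eval_dataset = fn_args.data_accessor.tf_dataset_factory(fn_args.eval_files, options, schema)
    model = build_keras_model()  # defined elsewhere in model.py
    model.fit(train_dataset, validation_data=eval_dataset,
              steps_per_epoch=fn_args.train_steps, validation_steps=fn_args.eval_steps)
    model.save(fn_args.serving_model_dir)
```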
Tuner — Hyperparameter Search
Optional component that wraps KerasTuner to search hyperparameter space before committing to a full training run. Once good hyperparameters are found, import them directly and skip re-tuning:
```python
hparams_importer = Importer(
    source_uri='path/to/best_hyperparameters.txt',
    artifact_type=HyperParameters,
).with_id('import_hparams')

trainer = Trainer(
    # ... other arguments as in the Trainer example above
    hyperparameters=hparams_importer.outputs['result'],
)
```
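The "tune once, reuse afterwards" pattern does not depend on KerasTuner. A sketch that swaps in a plain grid search and a JSON file (TFX's own file holds KerasTuner's hyperparameter config instead); `grid_search` and `load_or_tune` are hypothetical helpers.

```python
import itertools
import json
from pathlib import Path


def grid_search(objective, space: dict) -> dict:
    """Exhaustively score every combination in a small grid; return the best."""
    names = sorted(space)
    best, best_score = None, float("-inf")
    for combo in itertools.product(*(space[n] for n in names)):
        config = dict(zip(names, combo))
        score = objective(config)
        if score > best_score:
            best, best_score = config, score
    return best


def load_or_tune(path: Path, objective, space: dict) -> dict:
    """Reuse saved hyperparameters if the file exists; otherwise tune and save."""
    if path.exists():
        return json.loads(path.read_text())
    best = grid_search(objective, space)
    path.write_text(json.dumps(best, sort_keys=True))
    return best
```

The second and later runs read the file and never call the objective, which is the same saving the Importer gives you in the pipeline.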
Evaluator — Model Validation
Runs TensorFlow Model Analysis on the trained model against eval data. Computes sliced metrics (e.g., accuracy broken down by category) and compares them against a baseline model. If the new model doesn't pass the configured thresholds, Evaluator withholds its blessing and Pusher won't deploy it, so a regression can't reach production silently.
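The blessing decision boils down to two kinds of checks: an absolute floor and a maximum regression against the baseline. In TFX these are configured through TFMA's `MetricThreshold` (value and change thresholds); the sketch below reimplements only the logic in plain Python, and `Threshold` and `bless` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Threshold:
    """Hypothetical gate for one higher-is-better metric such as AUC."""
    lower_bound: float            # absolute floor the candidate must clear
    max_regression: float = 0.0   # allowed drop versus the baseline model


def bless(candidate: dict, baseline: Optional[dict], thresholds: dict) -> tuple[bool, list[str]]:
    """Return (blessed, reasons) for a candidate model's metrics."""
    failures = []
    for metric, gate in thresholds.items():
        value = candidate[metric]
        if value < gate.lower_bound:
            failures.append(f"{metric}={value:.3f} below floor {gate.lower_bound:.3f}")
        # On the first run there is no baseline, so only the floor applies.
        if baseline is not None and value < baseline[metric] - gate.max_regression:
            failures.append(f"{metric}={value:.3f} regressed from baseline {baseline[metric]:.3f}")
    return not failures, failures
```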
InfraValidator — Serving Infrastructure Check
Validates that the model loads correctly in the actual serving infrastructure before it’s pushed. Catches compatibility issues (TF version mismatches, incompatible SavedModel signatures) that Evaluator won’t catch because Evaluator only tests model quality.
Pusher — Model Deployment
Pushes a validated model to a serving destination (filesystem, TF Serving, Vertex AI). Only pushes if both Evaluator and InfraValidator have blessed the model.
```python
from tfx import v1 as tfx

pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    infra_blessing=infra_validator.outputs['blessing'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory='/serving/models'
        )
    ),
)
```
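The push gate combined with TF Serving's directory convention can be sketched with the standard library. TF Serving loads models from numbered version subdirectories under a base path; this sketch assumes a Unix timestamp as the version number, and `push_if_blessed` is a hypothetical helper, not Pusher's code.

```python
import shutil
import time
from pathlib import Path
from typing import Optional


def push_if_blessed(model_dir: Path, base_directory: Path, model_blessed: bool,
                    infra_blessed: bool, version: Optional[int] = None) -> Optional[Path]:
    """Copy model_dir to base_directory/<version> only if both gates passed."""
    if not (model_blessed and infra_blessed):
        return None  # missing a blessing: nothing is pushed
    version = version if version is not None else int(time.time())
    destination = base_directory / str(version)
    # copytree creates base_directory if needed and fails if the version exists.
    shutil.copytree(model_dir, destination)
    return destination
```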
BulkInferrer — Batch Inference
For running predictions on unlabeled data at scale. Takes a trained model and a dataset of examples, produces inference results without standing up a serving endpoint. Useful for offline scoring pipelines.
Custom Components
When standard components don’t cover a data source or processing step, TFX lets you write custom ones following the same artifact interface. The most common case is a custom ExampleGen for proprietary data formats.
Custom ExampleGen components use Apache Beam to parallelize ingestion across arbitrary data sources — databases, APIs, custom binary formats. The output is always tf.Example records, so downstream components work unchanged.
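The per-record part of a custom ExampleGen can be sketched independently of Beam. The record layout below (little-endian `uint32` trip id, `float32` fare, `uint8` payment code) is entirely hypothetical. In a real component this function would run inside a Beam transform and each dict would be converted to a `tf.train.Example`; both steps are omitted here.

```python
import struct

# Hypothetical fixed-width record layout: uint32 trip_id, float32 fare, uint8 code.
RECORD = struct.Struct("<IfB")
PAYMENT_CODES = {0: "Cash", 1: "Credit Card"}  # hypothetical code table


def parse_records(blob: bytes):
    """Yield one feature dict per fixed-width record in blob."""
    if len(blob) % RECORD.size:
        raise ValueError("truncated record")
    for trip_id, fare, code in RECORD.iter_unpack(blob):
        yield {
            "trip_id": trip_id,
            "fare": fare,
            # Unknown codes become a sentinel that schema validation can flag.
            "payment_type": PAYMENT_CODES.get(code, "UNKNOWN"),
        }
```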
The Full Pipeline DAG
A complete TFX pipeline for a supervised ML problem looks like:
```text
ExampleGen → StatisticsGen → SchemaGen → ExampleValidator
                                 ↓
                             Transform → Trainer ─┬→ Evaluator ──────┬→ Pusher
                                                  └→ InfraValidator ─┘
```
ExampleValidator and Transform can run in parallel after SchemaGen since they share inputs but don’t depend on each other’s outputs. TFX’s DAG executor figures this out automatically from artifact dependencies.
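How an executor can derive that parallelism from artifact dependencies can be sketched with Python's `graphlib`. This is a simplification: each component lists the artifact keys it consumes and produces (loosely following the pipeline above), and `execution_waves` is a hypothetical helper, not TFX's scheduler.

```python
from graphlib import TopologicalSorter

# component -> (artifact keys consumed, artifact keys produced); simplified.
COMPONENTS = {
    "ExampleGen": ([], ["examples"]),
    "StatisticsGen": (["examples"], ["statistics"]),
    "SchemaGen": (["statistics"], ["schema"]),
    "ExampleValidator": (["statistics", "schema"], ["anomalies"]),
    "Transform": (["examples", "schema"], ["transformed_examples", "transform_graph"]),
    "Trainer": (["transformed_examples", "transform_graph"], ["model"]),
    "Evaluator": (["examples", "model"], ["blessing"]),
    "InfraValidator": (["examples", "model"], ["infra_blessing"]),
    "Pusher": (["model", "blessing", "infra_blessing"], []),
}


def execution_waves(components: dict) -> list[list[str]]:
    """Group components into waves; everything in one wave can run in parallel."""
    producer = {art: name for name, (_, outs) in components.items() for art in outs}
    # Each component depends on whichever components produce its inputs.
    graph = {name: {producer[a] for a in ins} for name, (ins, _) in components.items()}
    sorter = TopologicalSorter(graph)
    sorter.prepare()
    waves = []
    while sorter.is_active():
        ready = sorted(sorter.get_ready())
        waves.append(ready)
        sorter.done(*ready)
    return waves
```

Nobody wrote "run ExampleValidator and Transform together"; it falls out of the fact that neither consumes the other's outputs.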
What I’d Do Differently Now
TFX’s strengths — metadata tracking, training-serving skew prevention, model validation gates — are the right ideas. The friction is in the TF-centric API surface: everything is TFRecord and SavedModel, so bringing non-TF models in requires adapter work.
In 2024, I’d evaluate Vertex AI Pipelines or Metaflow for the same patterns with less ceremony. But the conceptual architecture TFX established — artifact lineage, validation gates, frozen preprocessing graphs — is still the right way to think about production ML.