Statistical Typing: A Runtime Type System for Data Science and Machine Learning
Data science and machine learning rely on high-quality datasets for visualization, statistical inference, and modeling. Statistical typing is a runtime type system that enables data scientists, engineers, and analysts to validate real-world data and isolate units of processing, analysis, or model-training logic in order to implement more robust data testing.
The Data Quality Problem
One of the central challenges in data science (DS) and machine learning (ML) is managing and maintaining data quality. As an ML engineer and practitioner who frequently constructs, cleans, explores, and models proprietary (i.e. non-benchmark) datasets, I've found that "bad data" makes all the difference between accurate versus misleading data visualizations, statistical inferences, and models. In this article I want to home in on three problems that are fairly specific to DS/ML practice when it comes to dealing with tabular data in the Python ecosystem using pandas, which is one of the de facto tools for data manipulation in the DS/ML toolchain.
Tooling for type safety is improving in the Python ecosystem with the broad adoption of the typing module and projects like mypy, which make it easier to write readable, reliable code.
However, for most DS/ML work this isn't quite sufficient: logical data types don't always capture the statistical distributions of the variables under study, which matters when, for example, your data distribution shifts as a result of a world-wide pandemic, causing ML models to break in unexpected ways.
Having systems in place that fail early (and loudly) when the data distribution is not what you assumed is one of the critical pieces of building reliable production systems, and we need better tooling for this.
Another important tool in the developer's arsenal is testing. There are many books, articles, and discussions available online about different types of testing techniques and how to put them into practice. I won't dive deeply into it here, but in short, testing your code makes it easier to change, tells you when you've broken something, and serves as documentation at the same time.
Even in exploratory or research contexts, it's a good idea to write tests for your code because it strengthens your confidence in the robustness of the insights that you're taking away from your analysis.
The challenge with testing software compounds when processing data for the purpose of statistical analysis and modeling.
Consider a machine learning pipeline that creates a predictive model from survey responses. The barriers to testing the data and transformation code tend to be much higher than for the business logic that processes survey responses and stores raw values in a database, because the latter tends to be simpler and more atomic by design.
By atomic, I mean that each piece of data that's filled out by respondents and stored in the database can be tested in isolation without having to analyze the aggregate statistical patterns across a larger sample of responses. On the other hand, for my statistical analysis to make sense, the overall integrity of the statistical distribution of the responses needs to be taken into account.
Because the effort that goes into exploring, cleaning, and figuring out how to test my dataset is so high, I'm discouraged from writing tests for my pipeline code. As the famous software development quip goes:
"legacy code is code without tests"
- Michael C. Feathers, Working Effectively with Legacy Code
You might think: "But when I put my code in production, surely I (or one of my collaborators) will write tests then?" The thing is, regardless of who writes those tests or when they're written, someone will have to do it at some point, so the sooner you can climb the technical debt mountain the better!
In the rest of this post I'll try to convince you that statistical typing gives us the tools we need to do just that.
What's Statistical Typing?
If you've used strong, statically-typed languages before, or the mypy static type checker with type-hinting in Python, you may have noticed that type definitions can often catch nasty type-related bugs that render certain kinds of unit tests unnecessary. Other tools, like pydantic, enforce types at runtime via a data parsing model.
Statistical typing extends the concept of logical data types to the class of statistical data types and, ultimately, probability distributions. Statistical data types build on top of logical data types, and in fact there's considerable overlap between the two.
For example, the binary logical data type is also a statistical data type. The key difference is that statistical data types hold additional semantics that govern the kinds of statistical operations we can perform on variables of a particular type, and the probability distributions that describe those variables.
What if you could specify the set of acceptable values that a variable can take, whether in terms of its data type, its set or range of values, or even the distribution it's drawn from? This is the goal of statistical typing: to enumerate a practical set of constraints that specify what should be considered valid data for a particular dataset.
For example, we might want a categorical variable to be drawn somewhat uniformly
from a set of values {A, B, C}
. We can express this as a hypothesis test that
causes our pipeline to fail if any one of the values occurs significantly more
frequently than the others, given a pre-defined level of statistical significance.
Or we may want a real-valued variable to be drawn roughly from a normal
distribution with mean μ and variance σ², which can also be specified with an
alpha value that we deem acceptable for a particular analysis.
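To make this concrete, here's a minimal sketch of what those two constraints could look like as plain hypothesis tests using scipy, outside of any particular validation library; the category set and significance level are illustrative assumptions, not part of any real pipeline:

import pandas as pd
from scipy import stats

ALPHA = 0.01  # illustrative significance level


def check_uniform_categorical(series: pd.Series, categories=("A", "B", "C")):
    # chi-square goodness-of-fit test against a uniform distribution over the categories
    counts = series.value_counts().reindex(categories, fill_value=0)
    result = stats.chisquare(counts)
    assert result.pvalue > ALPHA, f"uniformity assumption rejected (p={result.pvalue:.4f})"


def check_roughly_normal(series: pd.Series):
    # D'Agostino-Pearson test of the hypothesis that the sample comes from a normal distribution
    result = stats.normaltest(series)
    assert result.pvalue > ALPHA, f"normality assumption rejected (p={result.pvalue:.4f})"

Failing the assertion would halt the pipeline early, which is exactly the "fail loudly" behavior described above.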
In essence, a statistical typing implementation involves specifying three kinds of metadata for a given set of variables:
1. logical data types, e.g. int, str, float, etc.
2. deterministic properties, e.g. categorical values and real-valued ranges
3. probabilistic properties, e.g. sufficient statistics like mean and standard deviation
The challenge presented by item 3 is obvious: discovering the underlying probability distributions of real-world data is often non-trivial. However, even if we can only express or automatically infer these metadata up to point 2, we can still get something quite powerful: property-based testing of statistical analysis code.
With a statistically-typed dataframe, not only can we validate real-world data to ensure that our assumptions about them hold up, but we can also test our data transformation, analysis, and modeling code given valid samples according to our schema definition. Statistical typing effectively gives DS/ML practitioners the tools to easily isolate their code from real-world data, providing a convenient way of implementing unit tests.
Let me illustrate how these concepts would work in practice with a toy problem using pandera, a runtime data validation library for pandas dataframes that I've been developing over the last few years.
Suppose you're building a predictive model of house prices given features about different houses:
raw_data = """
square_footage,n_bedrooms,property_type,price
750,1,condo,200000
900,2,condo,400000
1200,2,house,500000
1100,3,house,450000
1000,2,condo,300000
1000,2,townhouse,300000
1200,2,townhouse,350000
"""
In the raw data above you can see that we have the following columns:
- feature 1: square_footage
- feature 2: n_bedrooms
- feature 3: property_type
- target: price
Our modeling pipeline will involve two steps:
def process_data(raw_data):  # step 1: prepare data for model training
    ...

def train_model(processed_data):  # step 2: fit a model on processed data
    ...
At its core, pandera
provides a flexible and expressive API for defining
dataframe schemas and seamlessly integrating data validation logic into your
data analysis pipelines, all while separating the concerns of data cleaning
and validation.
import pandera as pa
from pandera.typing import Series
PROPERTY_TYPES = ["condo", "townhouse", "house"]
class BaseSchema(pa.SchemaModel):
    square_footage: Series[int] = pa.Field(in_range={"min_value": 0, "max_value": 3000})
    n_bedrooms: Series[int] = pa.Field(in_range={"min_value": 0, "max_value": 10})
    price: Series[float] = pa.Field(in_range={"min_value": 0, "max_value": 1000000})

    class Config:
        coerce = True


class RawData(BaseSchema):
    property_type: Series[str] = pa.Field(isin=PROPERTY_TYPES)


class ProcessedData(BaseSchema):
    property_type_condo: Series[int] = pa.Field(isin=[0, 1])
    property_type_house: Series[int] = pa.Field(isin=[0, 1])
    property_type_townhouse: Series[int] = pa.Field(isin=[0, 1])
In the code above, we can see that we're defining a BaseSchema
, which shares
columns that are common between the raw and processed data. We're also making
sure that the columns are coerced to the expected data types during validation.
RawData
and ProcessedData
inherit from BaseSchema
, and just by looking
at them we can see the difference that we expect between the raw and processed data:
our process_data
function should convert the property_type
categorical
variable into a set of dummy variables.
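These schemas cover logical data types and deterministic properties (items 1 and 2 from the earlier list). As a rough sketch of how a probabilistic property (item 3) might be layered on top, here's a hypothetical subclass of RawData with a custom check testing whether price is roughly normally distributed; the subclass name, the normality assumption, and the alpha threshold are illustrative and aren't used in the pipeline below:

from scipy import stats


class RawDataWithStats(RawData):
    # purely illustrative: fail validation if the price column deviates
    # strongly from a normal distribution (alpha chosen arbitrarily)
    @pa.check("price")
    def price_roughly_normal(cls, price: Series[float]) -> bool:
        return float(stats.normaltest(price).pvalue) > 0.01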
import pandas as pd
import pandera as pa
from pandera.typing import DataFrame
from sklearn.linear_model import LinearRegression
@pa.check_types
def process_data(raw_data: DataFrame[RawData]) -> DataFrame[ProcessedData]:
    return pd.get_dummies(
        raw_data.astype({"property_type": pd.CategoricalDtype(PROPERTY_TYPES)})
    )


@pa.check_types
def train_model(processed_data: DataFrame[ProcessedData]):
    estimator = LinearRegression()
    targets = processed_data["price"]
    features = processed_data.drop("price", axis=1)
    estimator.fit(features, targets)
    return estimator
Now every time we run our pipeline our data is validated as it passes through the various transformations:
from io import StringIO
def run_pipeline(raw_data):
    processed_data = process_data(raw_data)
    estimator = train_model(processed_data)
    # evaluate model, save artifacts, etc...
    print("model training successful!")


run_pipeline(pd.read_csv(StringIO(raw_data.strip())))
So if we pass invalid data into run_pipeline
, we should get an error:
invalid_data = """
square_footage,n_bedrooms,property_type,price
750,1,unknown,200000
900,2,condo,400000
1200,2,house,500000
"""
try:
    run_pipeline(pd.read_csv(StringIO(invalid_data.strip())))
except Exception as e:
    print(e)
Here, pandera tells us exactly what went wrong: the property_type column has an invalid category, unknown, at the 0th entry.
Property-based Testing
But wait, there's more! Since we've already defined our schemas, we can isolate the processing and model-training code from real-world data to test that each component in the pipeline is functioning as expected.
pandera
builds on top of the hypothesis
package to generate synthetic data from search strategies that try to find
the simplest case that would falsify your tests:
import hypothesis


@hypothesis.given(RawData.strategy(size=3))
def test_process_data(raw_data):
    process_data(raw_data)


@hypothesis.given(ProcessedData.strategy(size=3))
def test_train_model(processed_data):
    estimator = train_model(processed_data)
    preds = estimator.predict(processed_data.drop("price", axis=1))
    assert len(preds) == processed_data.shape[0]


def run_test_suite():
    test_process_data()
    test_train_model()
    print("tests successful!")


run_test_suite()
So if we were to incorrectly implement any of the components in our pipeline,
we'd see errors early on. In this case, we're just going to return the raw data
without the dummified property_type
variable.
@pa.check_types
def process_data(raw_data: DataFrame[RawData]) -> DataFrame[ProcessedData]:
    return raw_data


try:
    run_test_suite()
except Exception as e:
    print(e)
Here, our test suite catches the fact that property_type_condo
doesn't exist
in our processed data output.
We can get some more intuition about what's going on with the data synthesis
strategies by interactively generating data using the example
method.
RawData.example(size=3)
ProcessedData.example(size=3)
Under the hood, pandera collects all of the schema properties and converts them into a search strategy using the pandas-supported hypothesis strategies.
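For some intuition, here's roughly the kind of strategy pandera assembles for RawData, written out by hand with hypothesis's pandas extension. This is a simplified sketch of the idea, not the exact strategy pandera generates:

from hypothesis import strategies as st
from hypothesis.extra.pandas import column, data_frames

# each column gets an element strategy that respects the schema's
# logical type and its range/set constraints
raw_data_strategy = data_frames([
    column("square_footage", elements=st.integers(min_value=0, max_value=3000)),
    column("n_bedrooms", elements=st.integers(min_value=0, max_value=10)),
    column("property_type", elements=st.sampled_from(PROPERTY_TYPES)),
    column("price", elements=st.floats(min_value=0, max_value=1_000_000)),
])

# draw an example dataframe from the strategy, similar to RawData.example()
raw_data_strategy.example()

When a test fails, hypothesis then shrinks the generated dataframe toward the simplest case that still falsifies it.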
Currently, one limitation that you can see from the ProcessedData example above is that the
generated data doesn't quite capture the joint distribution between the property_type_* dummy
variables, as the second row contains 1s for all of the property types. Depending on
what exactly it is you're trying to test, this may or may not matter. Ultimately, it's
still up to you to determine what to test and how.
What's Next?
There's still a lot to do in pandera to fully realize the vision of
statistical typing, but I think the main API ideas and features are already there
to get started and reap the benefits:
- Runtime data validation when executing pipelines in development and production.
- Property-based unit testing by isolating transformation code from real data.
- Self-documenting pipelines that explicitly define the types and statistical properties of data as it flows through your pipeline.
There are a few things in the roadmap that I'm excited about:
- Decouple the pandera and pandas type systems
- Add support for parallelized dataframes for larger datasets
- Add a more comprehensive suite of built-in statistical hypothesis tests
- Implement data synthesis strategies for hypothesis tests
- Support data synthesis strategies for joint distributions
- Support machine-learning-specific schemas
If you're interested in this project, please consider helping out with code contributions, feature requests, bug reports, documentation improvements, and support!