diff --git a/causalpy/experiments/diff_in_diff.py b/causalpy/experiments/diff_in_diff.py
index 04b62370..132cd2ae 100644
--- a/causalpy/experiments/diff_in_diff.py
+++ b/causalpy/experiments/diff_in_diff.py
@@ -15,6 +15,8 @@
 Difference in differences
 """
 
+import re
+
 import arviz as az
 import numpy as np
 import pandas as pd
@@ -84,6 +86,7 @@ def __init__(
         formula: str,
         time_variable_name: str,
         group_variable_name: str,
+        post_treatment_variable_name: str = "post_treatment",
         model=None,
         **kwargs,
     ) -> None:
@@ -95,6 +98,7 @@ def __init__(
         self.formula = formula
         self.time_variable_name = time_variable_name
         self.group_variable_name = group_variable_name
+        self.post_treatment_variable_name = post_treatment_variable_name
         self.input_validation()
 
         y, X = dmatrices(formula, self.data)
@@ -128,6 +132,12 @@ def __init__(
             }
             self.model.fit(X=self.X, y=self.y, coords=COORDS)
         elif isinstance(self.model, RegressorMixin):
+            # For scikit-learn models, automatically set fit_intercept=False.
+            # This ensures the intercept is included in the coefficients array rather than being a separate intercept_ attribute.
+            # Without this, the intercept is not included in the coefficients array and hence would be displayed as 0 in the model summary.
+            # TODO: later, this should be handled in ScikitLearnAdaptor itself
+            if hasattr(self.model, "fit_intercept"):
+                self.model.fit_intercept = False
             self.model.fit(X=self.X, y=self.y)
         else:
             raise ValueError("Model type not recognized")
@@ -173,7 +183,7 @@ def __init__(
             # just the treated group
             .query(f"{self.group_variable_name} == 1")
             # just the treatment period(s)
-            .query("post_treatment == True")
+            .query(f"{self.post_treatment_variable_name} == True")
             # drop the outcome variable
             .drop(self.outcome_variable_name, axis=1)
             # We may have multiple units per time point, we only want one time point
@@ -189,7 +199,10 @@ def __init__(
         # INTERVENTION: set the interaction term between the group and the
         # post_treatment variable to zero. This is the counterfactual.
         for i, label in enumerate(self.labels):
-            if "post_treatment" in label and self.group_variable_name in label:
+            if (
+                self.post_treatment_variable_name in label
+                and self.group_variable_name in label
+            ):
                 new_x.iloc[:, i] = 0
         self.y_pred_counterfactual = self.model.predict(np.asarray(new_x))
 
@@ -198,31 +211,44 @@ def __init__(
             # This is the coefficient on the interaction term
             coeff_names = self.model.idata.posterior.coords["coeffs"].data
             for i, label in enumerate(coeff_names):
-                if "post_treatment" in label and self.group_variable_name in label:
+                if (
+                    self.post_treatment_variable_name in label
+                    and self.group_variable_name in label
+                ):
                     self.causal_impact = self.model.idata.posterior["beta"].isel(
                         {"coeffs": i}
                     )
         elif isinstance(self.model, RegressorMixin):
             # This is the coefficient on the interaction term
-            # TODO: CHECK FOR CORRECTNESS
-            self.causal_impact = (
-                self.y_pred_treatment[1] - self.y_pred_counterfactual[0]
-            ).item()
+            # Map each label in self.labels to its coefficient from get_coeffs(): {label: value}
+            coef_map = dict(zip(self.labels, self.model.get_coeffs()))
+            # Build the "<group>:<post_treatment>" label and take the first coefficient label containing it (NOTE(review): matched_key is None, and att is looked up with key None, if nothing matches)
+            interaction_term = (
+                f"{self.group_variable_name}:{self.post_treatment_variable_name}"
+            )
+            matched_key = next((k for k in coef_map if interaction_term in k), None)
+            att = coef_map.get(matched_key)
+            self.causal_impact = att
         else:
             raise ValueError("Model type not recognized")
 
         return
 
     def input_validation(self):
+        # Validate formula structure and interaction terms (NOTE(review): this call precedes the string below, so that string is no longer the method's docstring)
+        self._validate_formula_interaction_terms()
+
         """Validate the input data and model formula for correctness"""
-        if "post_treatment" not in self.formula:
+        # Check that the configured post-treatment variable name appears in the formula string
+        if self.post_treatment_variable_name not in self.formula:
             raise FormulaException(
-                "A predictor called `post_treatment` should be in the formula"
+                f"Missing required variable '{self.post_treatment_variable_name}' in formula"
             )
 
-        if "post_treatment" not in self.data.columns:
+        # Check that the configured post-treatment variable name is a column of the data
+        if self.post_treatment_variable_name not in self.data.columns:
             raise DataException(
-                "Require a boolean column labelling observations which are `treated`"
+                f"Missing required column '{self.post_treatment_variable_name}' in dataset"
             )
 
         if "unit" not in self.data.columns:
@@ -236,6 +262,61 @@ def input_validation(self):
                 coded. Consisting of 0's and 1's only."""
             )
 
+    def _get_interaction_terms(self):
+        """
+        Extract interaction terms from the right-hand side of ``self.formula``.
+        Returns a list of the terms containing '*' or ':', with leading '+'/'-' stripped.
+        """
+        # Define interaction indicators
+        INTERACTION_INDICATORS = ["*", ":"]
+
+        # Remove spaces (only the space character is stripped)
+        formula = self.formula.replace(" ", "")
+
+        # Extract right-hand side of the formula (the text after the first '~')
+        rhs = formula.split("~")[1]
+
+        # Split before every '+' or '-'; the zero-width lookahead keeps the sign on its term
+        terms = re.split(r"(?=[+-])", rhs)
+
+        # Clean up terms and get interaction terms (those with '*' or ':')
+        interaction_terms = []
+        for term in terms:
+            # Remove leading + or - for processing
+            clean_term = term.lstrip("+-")
+            if any(indicator in clean_term for indicator in INTERACTION_INDICATORS):
+                interaction_terms.append(clean_term)
+
+        return interaction_terms
+
+    def _validate_formula_interaction_terms(self):
+        """
+        Validate that the formula contains at most one interaction term and no three-way or higher-order interactions.
+        Raises FormulaException if more than one interaction term is found or if any interaction term has more than 2 variables.
+        """
+        # Define interaction indicators
+        INTERACTION_INDICATORS = ["*", ":"]
+
+        # Get interaction terms
+        interaction_terms = self._get_interaction_terms()
+
+        # Check for interaction terms with more than 2 variables (more than one '*' or ':')
+        for term in interaction_terms:
+            total_indicators = sum(
+                term.count(indicator) for indicator in INTERACTION_INDICATORS
+            )
+            if (
+                total_indicators >= 2
+            ):  # 3 or more variables (e.g., a*b*c or a:b:c has 2 symbols)
+                raise FormulaException(
+                    f"Formula contains interaction term with more than 2 variables: {term}. Only two-way interactions are allowed."
+                )
+
+        if len(interaction_terms) > 1:
+            raise FormulaException(
+                f"Formula contains more than 1 interaction term: {interaction_terms}. Maximum of 1 allowed."
+            )
+
     def summary(self, round_to=None) -> None:
         """Print summary of main results and model coefficients.
 
diff --git a/causalpy/tests/test_input_validation.py b/causalpy/tests/test_input_validation.py
index 43fd9208..69ca3753 100644
--- a/causalpy/tests/test_input_validation.py
+++ b/causalpy/tests/test_input_validation.py
@@ -30,18 +30,29 @@
 
 
 def test_did_validation_post_treatment_formula():
-    """Test that we get a FormulaException if do not include post_treatment in the
-    formula"""
+    """Test that we get a FormulaException for invalid formulas and missing post_treatment variables"""
     df = pd.DataFrame(
         {
             "group": [0, 0, 1, 1],
             "t": [0, 1, 0, 1],
             "unit": [0, 0, 1, 1],
             "post_treatment": [0, 1, 0, 1],
+            "male": [0, 1, 0, 1],  # Additional variable for testing
             "y": [1, 2, 3, 4],
         }
     )
 
+    df_with_custom = pd.DataFrame(
+        {
+            "group": [0, 0, 1, 1],
+            "t": [0, 1, 0, 1],
+            "unit": [0, 0, 1, 1],
+            "custom_post": [0, 1, 0, 1],  # Custom column name
+            "y": [1, 2, 3, 4],
+        }
+    )
+
+    # Test 1: Missing post_treatment variable in formula
     with pytest.raises(FormulaException):
         _ = cp.DifferenceInDifferences(
             df,
@@ -51,6 +62,7 @@ def test_did_validation_post_treatment_formula():
             model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
         )
 
+    # Test 2: Missing post_treatment variable in formula (duplicate test)
     with pytest.raises(FormulaException):
         _ = cp.DifferenceInDifferences(
             df,
@@ -60,6 +72,88 @@ def test_did_validation_post_treatment_formula():
             model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
         )
 
+    # Test 3: Custom post_treatment_variable_name but formula uses different name
+    with pytest.raises(FormulaException):
+        _ = cp.DifferenceInDifferences(
+            df_with_custom,
+            formula="y ~ 1 + group*post_treatment",  # Formula uses 'post_treatment'
+            time_variable_name="t",
+            group_variable_name="group",
+            post_treatment_variable_name="custom_post",  # But user specifies 'custom_post'
+            model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
+        )
+
+    # Test 4: Default post_treatment_variable_name but formula uses different name
+    with pytest.raises(FormulaException):
+        _ = cp.DifferenceInDifferences(
+            df,
+            formula="y ~ 1 + group*custom_post",  # Formula uses 'custom_post'
+            time_variable_name="t",
+            group_variable_name="group",
+            # post_treatment_variable_name defaults to "post_treatment"
+            model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
+        )
+
+    # Test 5: Repeated interaction terms (should be invalid)
+    with pytest.raises(FormulaException):
+        _ = cp.DifferenceInDifferences(
+            df,
+            formula="y ~ 1 + group + group*post_treatment + group*post_treatment",
+            time_variable_name="t",
+            group_variable_name="group",
+            model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
+        )
+
+    # Test 6: Three-way interactions using * (should be invalid)
+    with pytest.raises(FormulaException):
+        _ = cp.DifferenceInDifferences(
+            df,
+            formula="y ~ 1 + group + group*post_treatment*male",
+            time_variable_name="t",
+            group_variable_name="group",
+            model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
+        )
+
+    # Test 7: Three-way interactions using : (should be invalid)
+    with pytest.raises(FormulaException):
+        _ = cp.DifferenceInDifferences(
+            df,
+            formula="y ~ 1 + group + group:post_treatment:male",
+            time_variable_name="t",
+            group_variable_name="group",
+            model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
+        )
+
+    # Test 8: Multiple different interaction terms using * (should be invalid)
+    with pytest.raises(FormulaException):
+        _ = cp.DifferenceInDifferences(
+            df,
+            formula="y ~ 1 + group + group*post_treatment + group*male",
+            time_variable_name="t",
+            group_variable_name="group",
+            model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
+        )
+
+    # Test 9: Multiple different interaction terms using : (should be invalid)
+    with pytest.raises(FormulaException):
+        _ = cp.DifferenceInDifferences(
+            df,
+            formula="y ~ 1 + group + group:post_treatment + group:male",
+            time_variable_name="t",
+            group_variable_name="group",
+            model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
+        )
+
+    # Test 10: Mixed issues - multiple terms + three-way interaction (should be invalid)
+    with pytest.raises(FormulaException):
+        _ = cp.DifferenceInDifferences(
+            df,
+            formula="y ~ 1 + group + group*post_treatment + group:post_treatment:male",
+            time_variable_name="t",
+            group_variable_name="group",
+            model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
+        )
+
 
 
 def test_did_validation_post_treatment_data():
     """Test that we get a DataException if do not include post_treatment in the data"""
@@ -91,6 +185,27 @@ def test_did_validation_post_treatment_data():
             model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
         )
 
+    # Test 2: Custom post_treatment_variable_name but column doesn't exist in data
+    df_with_post = pd.DataFrame(
+        {
+            "group": [0, 0, 1, 1],
+            "t": [0, 1, 0, 1],
+            "unit": [0, 0, 1, 1],
+            "post_treatment": [0, 1, 0, 1],  # Data has 'post_treatment'
+            "y": [1, 2, 3, 4],
+        }
+    )
+
+    with pytest.raises(DataException):
+        _ = cp.DifferenceInDifferences(
+            df_with_post,
+            formula="y ~ 1 + group*custom_post",  # Formula uses 'custom_post'
+            time_variable_name="t",
+            group_variable_name="group",
+            post_treatment_variable_name="custom_post",  # User specifies 'custom_post'
+            model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
+        )
+
 
 
 def test_did_validation_unit_data():
     """Test that we get a DataException if do not include unit in the data"""
diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg
index 4704ef6c..08c36d5e 100644
--- a/docs/source/_static/interrogate_badge.svg
+++ b/docs/source/_static/interrogate_badge.svg
@@ -1,10 +1,10 @@
- interrogate: 95.5%
+ interrogate: 92.6%
-
+
@@ -12,8 +12,8 @@
 interrogate
 interrogate
- 95.5%
- 95.5%
+ 92.6%
+ 92.6%