diff --git a/skills/us-cpa/src/us_cpa/questions.py b/skills/us-cpa/src/us_cpa/questions.py index 3ee1502..1a01209 100644 --- a/skills/us-cpa/src/us_cpa/questions.py +++ b/skills/us-cpa/src/us_cpa/questions.py @@ -63,6 +63,28 @@ def _filing_status_label(status: str) -> str: return status.replace("_", " ").title() +FILING_STATUS_PATTERNS = ( + (("married filing jointly",), "married_filing_jointly"), + (("mfj",), "married_filing_jointly"), + (("head of household",), "head_of_household"), + (("hoh",), "head_of_household"), + (("married filing separately",), "married_filing_separately"), + (("mfs",), "married_filing_separately"), + (("single",), "single"), +) + + +def _infer_filing_status(normalized_question: str, case_facts: dict[str, Any]) -> str: + if "filingStatus" in case_facts: + return case_facts["filingStatus"] + + for patterns, filing_status in FILING_STATUS_PATTERNS: + if all(pattern in normalized_question for pattern in patterns): + return filing_status + + return "single" + + @dataclass class QuestionEngine: corpus: TaxYearCorpus @@ -102,7 +124,7 @@ class QuestionEngine: if all(keyword in normalized for keyword in rule["keywords"]): authorities = self._authorities_for(manifest, rule["authority_slugs"]) if rule["issue"] == "standard_deduction": - filing_status = case_facts.get("filingStatus", "single") + filing_status = _infer_filing_status(normalized, case_facts) answer = rule["answer_by_status"].get(filing_status, rule["answer_by_status"]["single"]) summary = rule["summary_template"].format( filing_status_label=_filing_status_label(filing_status), diff --git a/skills/us-cpa/tests/test_questions.py b/skills/us-cpa/tests/test_questions.py index f464488..c5a6ebf 100644 --- a/skills/us-cpa/tests/test_questions.py +++ b/skills/us-cpa/tests/test_questions.py @@ -37,6 +37,34 @@ class QuestionEngineTests(unittest.TestCase): self.assertTrue(analysis["authorities"]) self.assertEqual(analysis["authorities"][0]["sourceClass"], "irs_instructions") + def test_standard_deduction_infers_married_filing_jointly_from_question(self) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + engine = self.build_engine(temp_dir) + + analysis = engine.answer( + question="What is the standard deduction for married filing jointly?", + tax_year=2025, + case_facts={}, + ) + + self.assertEqual(analysis["issue"], "standard_deduction") + self.assertEqual(analysis["conclusion"]["answer"], "$31,500") + self.assertIn("Married Filing Jointly", analysis["conclusion"]["summary"]) + + def test_standard_deduction_infers_head_of_household_from_question(self) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + engine = self.build_engine(temp_dir) + + analysis = engine.answer( + question="What is the standard deduction for a head of household filer?", + tax_year=2025, + case_facts={}, + ) + + self.assertEqual(analysis["issue"], "standard_deduction") + self.assertEqual(analysis["conclusion"]["answer"], "$23,625") + self.assertIn("Head Of Household", analysis["conclusion"]["summary"]) + def test_complex_question_flags_primary_law_escalation(self) -> None: with tempfile.TemporaryDirectory() as temp_dir: engine = self.build_engine(temp_dir)