fix: infer filing status from us-cpa questions
This commit is contained in:
@@ -63,6 +63,28 @@ def _filing_status_label(status: str) -> str:
|
||||
return status.replace("_", " ").title()
|
||||
|
||||
|
||||
FILING_STATUS_PATTERNS = (
|
||||
(("married filing jointly",), "married_filing_jointly"),
|
||||
(("mfj",), "married_filing_jointly"),
|
||||
(("head of household",), "head_of_household"),
|
||||
(("hoh",), "head_of_household"),
|
||||
(("married filing separately",), "married_filing_separately"),
|
||||
(("mfs",), "married_filing_separately"),
|
||||
(("single",), "single"),
|
||||
)
|
||||
|
||||
|
||||
def _infer_filing_status(normalized_question: str, case_facts: dict[str, Any]) -> str:
|
||||
if "filingStatus" in case_facts:
|
||||
return case_facts["filingStatus"]
|
||||
|
||||
for patterns, filing_status in FILING_STATUS_PATTERNS:
|
||||
if all(pattern in normalized_question for pattern in patterns):
|
||||
return filing_status
|
||||
|
||||
return "single"
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuestionEngine:
|
||||
corpus: TaxYearCorpus
|
||||
@@ -102,7 +124,7 @@ class QuestionEngine:
|
||||
if all(keyword in normalized for keyword in rule["keywords"]):
|
||||
authorities = self._authorities_for(manifest, rule["authority_slugs"])
|
||||
if rule["issue"] == "standard_deduction":
|
||||
filing_status = case_facts.get("filingStatus", "single")
|
||||
filing_status = _infer_filing_status(normalized, case_facts)
|
||||
answer = rule["answer_by_status"].get(filing_status, rule["answer_by_status"]["single"])
|
||||
summary = rule["summary_template"].format(
|
||||
filing_status_label=_filing_status_label(filing_status),
|
||||
|
||||
@@ -37,6 +37,34 @@ class QuestionEngineTests(unittest.TestCase):
|
||||
self.assertTrue(analysis["authorities"])
|
||||
self.assertEqual(analysis["authorities"][0]["sourceClass"], "irs_instructions")
|
||||
|
||||
def test_standard_deduction_infers_married_filing_jointly_from_question(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
engine = self.build_engine(temp_dir)
|
||||
|
||||
analysis = engine.answer(
|
||||
question="What is the standard deduction for married filing jointly?",
|
||||
tax_year=2025,
|
||||
case_facts={},
|
||||
)
|
||||
|
||||
self.assertEqual(analysis["issue"], "standard_deduction")
|
||||
self.assertEqual(analysis["conclusion"]["answer"], "$31,500")
|
||||
self.assertIn("Married Filing Jointly", analysis["conclusion"]["summary"])
|
||||
|
||||
def test_standard_deduction_infers_head_of_household_from_question(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
engine = self.build_engine(temp_dir)
|
||||
|
||||
analysis = engine.answer(
|
||||
question="What is the standard deduction for a head of household filer?",
|
||||
tax_year=2025,
|
||||
case_facts={},
|
||||
)
|
||||
|
||||
self.assertEqual(analysis["issue"], "standard_deduction")
|
||||
self.assertEqual(analysis["conclusion"]["answer"], "$23,625")
|
||||
self.assertIn("Head Of Household", analysis["conclusion"]["summary"])
|
||||
|
||||
def test_complex_question_flags_primary_law_escalation(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
engine = self.build_engine(temp_dir)
|
||||
|
||||
Reference in New Issue
Block a user