Skip to content

Commit dbea604

Browse files
authored
Add input validation for structured-data-classification (#97)
1 parent cc91e1d commit dbea604

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

api-inference-community/api_inference_community/validation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,18 @@ def all_rows_must_have_same_length(cls, table: Dict[str, List[str]]):
132132
raise ValueError("All rows in the table must be the same length")
133133

134134

135+
class StructuredDataClassificationInputsCheck(BaseModel):
    """Pydantic schema for structured-data-classification payloads.

    Expects ``data`` to be a mapping of column name -> list of cell values,
    where every column has the same number of entries (a rectangular table).
    """

    # Column name -> column values. All columns must have equal length.
    data: Dict[str, List[str]]

    @validator("data")
    def all_rows_must_have_same_length(cls, data: Dict[str, List[str]]):
        """Validate that the table is rectangular.

        Returns ``data`` unchanged when every column has the same length.
        Raises ``ValueError`` (surfaced by pydantic as a ValidationError)
        otherwise.
        """
        columns = list(data.values())
        # An empty table is trivially rectangular. Without this guard,
        # columns[0] below would raise IndexError instead of validating.
        if not columns:
            return data
        expected_len = len(columns[0])
        if all(len(column) == expected_len for column in columns):
            return data
        raise ValueError("All rows in the data must be the same length")
145+
146+
135147
class StringOrStringBatchInputCheck(BaseModel):
136148
__root__: Union[List[str], str]
137149

@@ -164,6 +176,7 @@ class StringInput(BaseModel):
164176
"feature-extraction": StringOrStringBatchInputCheck,
165177
"sentence-similarity": SentenceSimilarityInputsCheck,
166178
"table-question-answering": TableQuestionAnsweringInputsCheck,
179+
"structured-data-classification": StructuredDataClassificationInputsCheck,
167180
"fill-mask": StringInput,
168181
"summarization": StringInput,
169182
"text2text-generation": StringInput,

api-inference-community/tests/test_nlp.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,39 @@ def test_no_query(self):
152152
normalize_payload_nlp(bpayload, "table-question-answering")
153153

154154

155+
class StructuredDataClassificationValidationTestCase(TestCase):
    """Validation tests for the structured-data-classification task payload."""

    def test_valid_input(self):
        # A rectangular table (all columns equal length) must pass validation
        # and come back unchanged, with no extra parameters.
        table = {
            "Repository": ["Transformers", "Datasets", "Tokenizers"],
            "Stars": ["36542", "4512", "3934"],
        }
        wrapped = {"data": table}
        encoded = json.dumps({"inputs": wrapped}).encode("utf-8")
        normalized_inputs, processed_params = normalize_payload_nlp(
            encoded, "structured-data-classification"
        )
        self.assertEqual(processed_params, {})
        self.assertEqual(wrapped, normalized_inputs)

    def test_invalid_data_lengths(self):
        # Columns of differing lengths (3 vs 2) must be rejected.
        table = {
            "Repository": ["Transformers", "Datasets", "Tokenizers"],
            "Stars": ["36542", "4512"],
        }
        encoded = json.dumps({"inputs": {"data": table}}).encode("utf-8")
        with self.assertRaises(ValidationError):
            normalize_payload_nlp(encoded, "structured-data-classification")

    def test_invalid_data_type(self):
        # "data" must be a mapping of columns, not a bare string.
        encoded = json.dumps({"inputs": {"data": "Invalid data"}}).encode("utf-8")
        with self.assertRaises(ValidationError):
            normalize_payload_nlp(encoded, "structured-data-classification")
187+
155188
class SummarizationValidationTestCase(TestCase):
156189
def test_no_params(self):
157190
bpayload = json.dumps({"inputs": "whatever"}).encode("utf-8")

0 commit comments

Comments
 (0)