Skip to content

Commit 9bec20c

Browse files
committed
add SklearnPredictionTypes
1 parent 3bdb443 commit 9bec20c

File tree

7 files changed

+58
-44
lines changed

7 files changed

+58
-44
lines changed

vetiver/handlers/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def handler_startup():
121121
"""
122122
...
123123

124-
def handler_predict(self, input_data, check_prototype):
124+
def handler_predict(self, input_data, check_prototype, **kw):
125125
"""Generates method for /predict endpoint in VetiverAPI
126126
127127
The `handler_predict` function executes at each API call. Use this

vetiver/handlers/spacy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def construct_prototype(self):
5353

5454
return prototype
5555

56-
def handler_predict(self, input_data, check_prototype):
56+
def handler_predict(self, input_data, check_prototype, **kw):
5757
"""
5858
Generates method for /predict endpoint in VetiverAPI
5959

vetiver/handlers/statsmodels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class StatsmodelsHandler(BaseHandler):
2222
if sm_exists:
2323
pip_name = "statsmodels"
2424

25-
def handler_predict(self, input_data, check_prototype):
25+
def handler_predict(self, input_data, check_prototype, **kw):
2626
"""
2727
Generates method for /predict endpoint in VetiverAPI
2828

vetiver/handlers/torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class TorchHandler(BaseHandler):
2222
if torch_exists:
2323
pip_name = "torch"
2424

25-
def handler_predict(self, input_data, check_prototype):
25+
def handler_predict(self, input_data, check_prototype, **kw):
2626
"""
2727
Generates method for /predict endpoint in VetiverAPI
2828

vetiver/handlers/xgboost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class XGBoostHandler(BaseHandler):
2222
if xgb_exists:
2323
pip_name = "xgboost"
2424

25-
def handler_predict(self, input_data, check_prototype):
25+
def handler_predict(self, input_data, check_prototype, **kw):
2626
"""
2727
Generates method for /predict endpoint in VetiverAPI
2828

vetiver/server.py

Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,16 @@
1111
import pandas as pd
1212
import requests
1313
import uvicorn
14-
from fastapi import FastAPI, Request
14+
from fastapi import FastAPI
1515
from fastapi.exceptions import RequestValidationError
1616
from fastapi.openapi.utils import get_openapi
1717
from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse
18-
1918
from .helpers import api_data_to_frame, response_to_frame
19+
from .handlers.sklearn import SKLearnHandler
2020
from .meta import VetiverMeta
2121
from .utils import _jupyter_nb, get_workbench_path
2222
from .vetiver_model import VetiverModel
23+
from .types import SklearnPredictionTypes
2324

2425

2526
class VetiverAPI:
@@ -111,7 +112,6 @@ async def startup_event():
111112

112113
@app.get("/", include_in_schema=False)
113114
def docs_redirect():
114-
115115
redirect = "__docs__"
116116

117117
return RedirectResponse(redirect)
@@ -200,65 +200,75 @@ async def validation_exception_handler(request, exc):
200200

201201
return app
202202

203-
def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
204-
"""Create new POST endpoint that is aware of model input data
203+
def vetiver_post(
204+
self,
205+
endpoint_fx: Union[Callable, SklearnPredictionTypes],
206+
endpoint_name: str = None,
207+
**kw,
208+
):
209+
"""Define a new POST endpoint that utilizes the model's input data.
205210
206211
Parameters
207212
----------
208-
endpoint_fx : typing.Callable
209-
Custom function to be run at endpoint
213+
endpoint_fx : Union[typing.Callable, Literal["predict", "predict_proba", "predict_log_proba"]]
214+
A callable function that specifies the custom logic to execute when the endpoint is called.
215+
This function should take input data (e.g., a DataFrame or dictionary) and return the desired output
216+
(e.g., predictions or transformed data). For scikit-learn models, endpoint_fx can also be one of
217+
"predict", "predict_proba", or "predict_log_proba" if the model supports these methods.
218+
210219
endpoint_name : str
211-
Name of endpoint
220+
The name of the endpoint to be created.
212221
213222
Examples
214223
-------
215-
```{python}
224+
```python
216225
from vetiver import mock, VetiverModel, VetiverAPI
217226
X, y = mock.get_mock_data()
218227
model = mock.get_mock_model().fit(X, y)
219228
220-
v = VetiverModel(model = model, model_name = "model", prototype_data = X)
221-
v_api = VetiverAPI(model = v, check_prototype = True)
229+
v = VetiverModel(model=model, model_name="model", prototype_data=X)
230+
v_api = VetiverAPI(model=v, check_prototype=True)
222231
223232
def sum_values(x):
224233
return x.sum()
234+
225235
v_api.vetiver_post(sum_values, "sums")
226236
```
227237
"""
228-
if not endpoint_name:
229-
endpoint_name = endpoint_fx.__name__
230238

231-
if endpoint_fx.__doc__ is not None:
232-
api_desc = dedent(endpoint_fx.__doc__)
233-
else:
234-
api_desc = None
235-
236-
if self.check_prototype is True:
237-
238-
@self.app.post(
239-
urljoin("/", endpoint_name),
240-
name=endpoint_name,
241-
description=api_desc,
239+
if isinstance(endpoint_fx, SklearnPredictionTypes):
240+
if not isinstance(self.model, SKLearnHandler):
241+
raise ValueError(
242+
"The 'endpoint_fx' parameter can only be a string when using scikit-learn models."
243+
)
244+
self.vetiver_post(
245+
self.model.handler_predict,
246+
SklearnPredictionTypes,
247+
check_prototype=self.check_prototype,
248+
prediction_type=endpoint_fx,
242249
)
243-
async def custom_endpoint(input_data: List[self.model.prototype]):
244-
_to_frame = api_data_to_frame(input_data)
245-
predictions = endpoint_fx(_to_frame, **kw)
246-
if isinstance(predictions, List):
247-
return {endpoint_name: predictions}
248-
else:
249-
return predictions
250+
return
250251

251-
else:
252+
endpoint_name = endpoint_name or endpoint_fx.__name__
253+
endpoint_doc = dedent(endpoint_fx.__doc__) if endpoint_fx.__doc__ else None
252254

253-
@self.app.post(urljoin("/", endpoint_name))
254-
async def custom_endpoint(input_data: Request):
255+
@self.app.post(
256+
urljoin("/", endpoint_name),
257+
name=endpoint_name,
258+
description=endpoint_doc,
259+
)
260+
async def custom_endpoint(input_data: List[self.model.prototype]):
261+
if self.check_prototype:
262+
served_data = api_data_to_frame(input_data)
263+
else:
255264
served_data = await input_data.json()
256-
predictions = endpoint_fx(served_data, **kw)
257265

258-
if isinstance(predictions, List):
259-
return {endpoint_name: predictions}
260-
else:
261-
return predictions
266+
predictions = endpoint_fx(served_data, **kw)
267+
268+
if isinstance(predictions, List):
269+
return {endpoint_name: predictions}
270+
else:
271+
return predictions
262272

263273
def run(self, port: int = 8000, host: str = "127.0.0.1", quiet_open=False, **kw):
264274
"""

vetiver/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pydantic import BaseModel, create_model
2+
from typing import Literal
23

34
all = ["Prototype", "create_prototype"]
45

@@ -7,5 +8,8 @@ class Prototype(BaseModel):
78
pass
89

910

11+
SklearnPredictionTypes = Literal["predict", "predict_proba", "predict_log_proba"]
12+
13+
1014
def create_prototype(**dict_data):
1115
return create_model("prototype", __base__=Prototype, **dict_data)

0 commit comments

Comments
 (0)