File size: 5,566 Bytes
f228a1c
 
 
 
99351b6
338f4c1
 
c311b0d
f228a1c
 
 
 
 
 
 
 
 
 
 
 
 
 
338f4c1
 
f228a1c
14c8502
 
 
 
 
 
 
 
 
 
 
 
 
f228a1c
 
 
14c8502
 
 
 
 
 
99351b6
 
 
 
 
 
 
 
f43f094
 
 
 
5c19b8d
 
 
 
 
 
 
 
 
4fa87d4
5c19b8d
 
 
4fa87d4
5c19b8d
 
 
 
4fa87d4
 
 
 
5c19b8d
 
 
 
 
c311b0d
 
 
 
 
 
 
 
 
 
5c19b8d
 
 
 
 
99351b6
338f4c1
99351b6
338f4c1
f43f094
99351b6
 
f228a1c
bc21776
338f4c1
f228a1c
338f4c1
14c8502
 
 
 
f228a1c
 
 
5c19b8d
4fa87d4
5c19b8d
4fa87d4
5c19b8d
 
 
 
 
 
 
 
c311b0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f228a1c
 
5c19b8d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import torch
from detoxify import Detoxify
import asyncio
from fastapi.concurrency import run_in_threadpool
from typing import List, Optional

class Guardrail:
    def __init__(self):
        tokenizer = AutoTokenizer.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
        model = AutoModelForSequenceClassification.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
        self.classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            truncation=True,
            max_length=512,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )

    async def guard(self, prompt):
        return await run_in_threadpool(self.classifier, prompt)

    def determine_level(self, label, score):
        if label == "SAFE":
            return 0, "safe"
        else:
            if score > 0.9:
                return 4, "high"
            elif score > 0.75:
                return 3, "medium"
            elif score > 0.5:
                return 2, "low"
            else:
                return 1, "very low"

class TextPrompt(BaseModel):
    prompt: str

class ClassificationResult(BaseModel):
    label: str
    score: float
    level: int
    severity_label: str

class ToxicityResult(BaseModel):
    toxicity: float
    severe_toxicity: float
    obscene: float
    threat: float
    insult: float
    identity_attack: float

    @classmethod
    def from_dict(cls, data: dict):
        return cls(**{k: float(v) for k, v in data.items()})

class TopicBannerClassifier:
    def __init__(self):
        self.classifier = pipeline(
            "zero-shot-classification",
            model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.hypothesis_template = "This text is about {}"

    async def classify(self, text, labels):
        return await run_in_threadpool(
            self.classifier,
            text,
            labels,
            hypothesis_template=self.hypothesis_template,
            multi_label=False
        )

class TopicBannerRequest(BaseModel):
    prompt: str
    labels: List[str]

class TopicBannerResult(BaseModel):
    sequence: str
    labels: list
    scores: list

class GuardrailsRequest(BaseModel):
    prompt: str
    guardrails: List[str]
    labels: Optional[List[str]] = None

class GuardrailsResponse(BaseModel):
    prompt_injection: Optional[ClassificationResult] = None
    toxicity: Optional[ToxicityResult] = None
    topic_banner: Optional[TopicBannerResult] = None

app = FastAPI()
guardrail = Guardrail()
toxicity_classifier = Detoxify('original')
topic_banner_classifier = TopicBannerClassifier()

@app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
async def classify_toxicity(text_prompt: TextPrompt):
    try:
        result = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt)
        return ToxicityResult.from_dict(result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/models/PromptInjection/classify", response_model=ClassificationResult)
async def classify_text(text_prompt: TextPrompt):
    try:
        result = await guardrail.guard(text_prompt.prompt)
        label = result[0]['label']
        score = result[0]['score']
        level, severity_label = guardrail.determine_level(label, score)
        return {"label": label, "score": score, "level": level, "severity_label": severity_label}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
async def classify_topic_banner(request: TopicBannerRequest):
    try:
        result = await topic_banner_classifier.classify(request.prompt, request.labels)
        return {
            "sequence": result["sequence"],
            "labels": result["labels"],
            "scores": result["scores"]
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/guardrails", response_model=GuardrailsResponse)
async def evaluate_guardrails(request: GuardrailsRequest):
    tasks = []
    response = GuardrailsResponse()

    if "pi" in request.guardrails:
        tasks.append(classify_text(TextPrompt(prompt=request.prompt)))
    if "tox" in request.guardrails:
        tasks.append(classify_toxicity(TextPrompt(prompt=request.prompt)))
    if "top" in request.guardrails:
        if not request.labels:
            raise HTTPException(status_code=400, detail="Labels are required for topic banner classification")
        tasks.append(classify_topic_banner(TopicBannerRequest(prompt=request.prompt, labels=request.labels)))

    results = await asyncio.gather(*tasks, return_exceptions=True)

    for result, guardrail in zip(results, request.guardrails):
        if isinstance(result, Exception):
            # Handle the exception as needed
            continue
        if guardrail == "pi":
            response.prompt_injection = result
        elif guardrail == "tox":
            response.toxicity = result
        elif guardrail == "top":
            response.topic_banner = result

    return response

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)