Having finished the Prompt Engineering with Llama 2 course, I realised that Llama Guard could be exactly what I'm looking for. I work on child-safeguarding projects for schools, but I don't have labelled data for some topics (e.g. self-harm and bullying). I have captured text, but it isn't labelled, and the existing keyword-list matching produces large numbers of false positives, so I want to use Llama Guard as a first pass. However, I can't get it to work except for single cases. I'm working on Databricks: if I distribute the calls using Fugue, the job always fails with an out-of-GPU-memory error. This is using the moderate_with_template() method from the course and the model loaded from Hugging Face.
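For reference, the model and tokenizer are loaded roughly like this (a sketch assuming the meta-llama/LlamaGuard-7b checkpoint in fp16 on a single GPU, as in the course; my notebook may differ in minor details):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/LlamaGuard-7b"
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to(device)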
captures_table = spark.sql('''
    SELECT capture_pk, keyword_context
    FROM data_vault_dev.dbt_vault.sat_capture
    WHERE policy = 'Self-harm' AND model_prediction = -1
    LIMIT 10
''')
def moderate_with_template(chat):
    # Apply Llama Guard's chat template, generate a verdict, and return only the newly generated tokens.
    input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
    output = model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
    prompt_len = input_ids.shape[-1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
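Called on a single conversation this works exactly as expected, e.g. (illustrative text rather than a real capture):

chat = [{"role": "user", "content": "Some example capture text."}]
print(moderate_with_template(chat))  # prints Llama Guard's verdict, e.g. 'safe'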
import pandas as pd

def predict(df: pd.DataFrame, model) -> pd.DataFrame:
    # model arrives via Fugue's params; moderate_with_template still uses the module-level
    # model and tokenizer. Each row is moderated as its own single-turn chat.
    df['model_prediction'] = df.keyword_context.apply(
        lambda x: moderate_with_template([{"role": "user", "content": x.replace("\n", " ")}])
    )
    return df
from fugue import transform

result = transform(
    captures_table,
    predict,
    schema="*+model_prediction:str",
    params=dict(model=model),
    engine=spark
)
This gives an out-of-GPU-memory error. The following, by contrast, works, but is of course slow:
captures = captures_table.toPandas()
texts = pd.Series(
    [[{"role": "user", "content": context}] for context in captures.keyword_context]
)
print(texts)
results = texts.apply(moderate_with_template)
And apart from being slow, this approach leaves me with a bare Series of verdicts, without the association I need between each verdict and the capture it belongs to.