Using Llama Guard at scale

Having finished the Prompt Engineering with Llama 2 course, I realised that Llama Guard could be exactly what I'm looking for. I work on child safeguarding projects for schools, but I don't have labelled data for some topics (e.g. self-harm and bullying). I have captured text, but it isn't labelled. So I thought I would use Llama Guard as a first pass (the existing keyword lists produce large numbers of false positives). However, I can't get it to work except for single cases. I'm working on Databricks: if I generate distributed calls using Fugue, it always says I have run out of GPU memory. This is using the moderate_with_template() method from the course and the model loaded from Hugging Face.

captures_table = spark.sql('''SELECT capture_pk,keyword_context FROM data_vault_dev.dbt_vault.sat_capture 
                           WHERE policy = 'Self-harm' AND model_prediction = -1 LIMIT 10''')

def moderate_with_template(chat):
    # Apply the Llama Guard chat template and move the token ids to the GPU
    input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
    output = model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
    # Decode only the newly generated tokens (the verdict), not the prompt
    prompt_len = input_ids.shape[-1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
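
(For reference, tokenizer, model, and device are globals here. A minimal sketch of how they might be loaded, assuming the standard meta-llama/LlamaGuard-7b checkpoint from Hugging Face, as in the course:)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed setup, not shown in the original post
model_id = "meta-llama/LlamaGuard-7b"
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map=device
)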

def predict(df: pd.DataFrame) -> pd.DataFrame:
    # Each chat must be a list containing the single message dict
    chats = df.keyword_context.apply(
        lambda x: [{"role": "user", "content": x.replace("\n", " ")}]
    )
    # Moderate row by row so each verdict stays paired with its source text
    df['model_prediction'] = chats.apply(moderate_with_template)
    return df

result = transform(
  captures_table,
  predict,
  schema="*+model_prediction:str",
  params=dict(model=model),
  engine=spark
)

This gives an OOM error. The following works, but is of course slow:

captures = captures_table.toPandas()

# Wrap each captured text in a single-message chat
texts = pd.Series(
    [[{"role": "user", "content": context}] for context in captures.keyword_context]
)
print(texts)
results = texts.apply(moderate_with_template)

Apart from being slow, this Series also doesn't preserve the required association between each verdict and its source text.
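
One way to restore that association (a sketch, assuming results lines up with the default RangeIndex that toPandas() gives captures) is to write the verdicts back as a column:

# The list comprehension preserves row order, so results shares the
# default index of the captures DataFrame
captures["model_prediction"] = results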


Hello @Alun_ap_Rhisiart

Could you share a screenshot of the error?

Regards
DP

Hi. Thanks for the response. I have solved the second part, the alternating problem: I had an array of dictionaries and iterated over it to send one dictionary at a time, when in fact each item needed to be an array containing the single dictionary (a minimal illustration is below). I've updated the code above with the working version. What I still need is to use either pandas_udfs or (preferably) Fugue, so that I can run over the DataFrame, store each verdict back into the same row, and have it all run at a decent speed.
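
The shape fix, illustrated (the content string is just a placeholder):

# Wrong: a bare message dict, sent one at a time
chat = {"role": "user", "content": "some captured text"}

# Right: a list containing the single message dict
chat = [{"role": "user", "content": "some captured text"}]
verdict = moderate_with_template(chat)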
