- Week 3
- Link to the classroom item you are referring to: Lab_3_fine_t… - JupyterLab
For this function:

# Imports needed by this function
from datasets import load_dataset
from transformers import AutoTokenizer

def build_dataset(model_name,
                  dataset_name,
                  input_min_text_length,
                  input_max_text_length):
    """
    Preprocess the dataset and split it into train and test parts.

    Parameters:
    - model_name (str): Tokenizer model name.
    - dataset_name (str): Name of the dataset to load.
    - input_min_text_length (int): Minimum length of the dialogues.
    - input_max_text_length (int): Maximum length of the dialogues.

    Returns:
    - dataset_splits (datasets.dataset_dict.DatasetDict): Preprocessed dataset containing train and test parts.
    """

    # load dataset (only "train" part will be enough for this lab).
    dataset = load_dataset(dataset_name, split="train")

    # Filter the dialogues of length between input_min_text_length and input_max_text_length characters.
    dataset = dataset.filter(lambda x: len(x["dialogue"]) > input_min_text_length and len(x["dialogue"]) <= input_max_text_length, batched=False)

    # Prepare tokenizer. Setting device_map="auto" allows to switch between GPU and CPU automatically.
    tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")

    def tokenize(sample):
        # Wrap each dialogue with the instruction.
        prompt = f"""
Summarize the following conversation.
{sample["dialogue"]}
Summary:
"""
        sample["input_ids"] = tokenizer.encode(prompt)

        # This must be called "query", which is a requirement of our PPO library.
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    # Tokenize each dialogue.
    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")

    # Split the dataset into train and test parts.
    dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed=42)

    return dataset_splits
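The two preprocessing rules in `build_dataset` (the length filter and the prompt template) can be checked in isolation, without downloading anything. A minimal pure-Python sketch; `keep_dialogue` and `wrap_prompt` are hypothetical helper names that mirror the filter lambda and the prompt template exactly as pasted above, with no tokenizer involved:

```python
def keep_dialogue(dialogue: str, min_len: int, max_len: int) -> bool:
    """Same condition as the filter lambda: min_len < len(dialogue) <= max_len."""
    return min_len < len(dialogue) <= max_len


def wrap_prompt(dialogue: str) -> str:
    """Same instruction template as tokenize(), as pasted above."""
    return f"""
Summarize the following conversation.
{dialogue}
Summary:
"""


# With the limits used in the call (200, 1000): 200 characters is dropped, 201 kept.
print(keep_dialogue("x" * 200, 200, 1000), keep_dialogue("x" * 201, 200, 1000))  # prints: False True
```

So with `input_min_text_length=200`, a dialogue of exactly 200 characters is excluded while one of exactly 1000 characters is kept, because the lower bound is strict and the upper bound is inclusive.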
dataset = build_dataset(model_name=model_name,
                        dataset_name=huggingface_dataset_name,
                        input_min_text_length=200,
                        input_max_text_length=1000)
print(dataset)

I am getting this error:

TimeoutError: _ssl.c:989: The handshake operation timed out

The above exception was the direct cause of the following exception:

ReadTimeout: (ReadTimeoutError("HTTPSConnectionPool(host='cas-bridge.xethub.hf.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 25c3654c-018d-4099-8b03-b517fe4187c5)')
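The `read timeout=10` in this traceback matches the 10-second default download timeout of `huggingface_hub`, so one hedged workaround is to raise that timeout and retry transient network failures. The sketch below assumes a `huggingface_hub` version that reads the `HF_HUB_DOWNLOAD_TIMEOUT` environment variable when it is first imported (so it has to be set before `datasets`/`transformers` are imported, or after a kernel restart). `call_with_retries` is a hypothetical helper, not part of any library:

```python
import os
import time
from typing import Callable, Tuple, Type, TypeVar

# Assumption: huggingface_hub reads HF_HUB_DOWNLOAD_TIMEOUT (seconds, default 10)
# once, when it is first imported. Set it before importing datasets/transformers;
# in Jupyter, restart the kernel if they are already imported.
os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", "60")

T = TypeVar("T")


def call_with_retries(fn: Callable[[], T],
                      attempts: int = 3,
                      base_delay: float = 2.0,
                      retry_on: Tuple[Type[BaseException], ...] = (OSError,)) -> T:
    """Call fn(), retrying with exponential backoff if it raises one of retry_on.

    OSError covers TimeoutError, and requests' ReadTimeout (RequestException
    subclasses IOError, an alias of OSError), as in the traceback above.
    """
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except retry_on as exc:
            if attempt == attempts:
                raise  # out of attempts: surface the original error
            delay = base_delay * 2 ** (attempt - 1)  # 2 s, 4 s, 8 s, ...
            print(f"Attempt {attempt} failed ({exc!r}); retrying in {delay:g}s")
            time.sleep(delay)
    raise AssertionError("unreachable")
```

In the notebook this would wrap the existing call, e.g. `dataset = call_with_retries(lambda: build_dataset(model_name=model_name, dataset_name=huggingface_dataset_name, input_min_text_length=200, input_max_text_length=1000))`. The retry count is kept small on purpose: a transient hiccup gets a few seconds to clear, while a persistent problem (proxy, firewall, blocked host) still surfaces quickly with the original exception. Catching all of `OSError` is broad; narrowing `retry_on` avoids retrying errors such as a missing local file.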