I have installed Trax on my Ubuntu machine and trained the model on one core. However, when I pass a new parameter to training.Loop, n_devices = 12 (see below), it throws a JAX error: "JAX cannot work yet with n_devices != all devices: 12 != 1", which implies that JAX only sees one device. Could I ask for your help in exposing my entire CPU to JAX/Trax?
With kind regards,
Jaan Übi
training_loop = training.Loop(
    NER,                        # the model to train
    train_task,                 # the training task
    eval_tasks=[eval_task],     # the evaluation task(s)
    output_dir=output_dir,      # the output directory
    n_devices=12                # request all 12 CPU cores (triggers the error)
)
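For anyone hitting the same error: by default JAX exposes the host CPU as a single device, so jax.device_count() returns 1. One known workaround is to ask XLA to split the host into several logical devices before JAX is imported. This is a minimal sketch of that flag, not something from the original post, and whether training.Loop then accepts n_devices=12 depends on your Trax version:

import os

# Must be set before jax is imported; splits the host CPU into
# 12 logical devices so jax.device_count() reports 12.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=12"

import jax
print(jax.device_count())  # expected: 12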
Once a device is chosen, the available cores are still used to parallelize computations within it.
So the CPU cores on your laptop are being used even when you don't set n_devices. Open a system monitor and see for yourself by setting train_steps to a number like 500. Be sure to close other applications and run the Python file from the command line.
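A quick way to verify this (a minimal sketch, assuming a standard jax install): JAX reports a single CPU device, yet a large computation will light up every core in your system monitor:

import jax
import jax.numpy as jnp

print(jax.devices())  # typically one logical CPU device, e.g. [CpuDevice(id=0)]

x = jnp.ones((4096, 4096))
# XLA parallelizes this matmul across all CPU cores,
# even though JAX reports a single device.
(x @ x).block_until_ready()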
I have a follow-up question. If we want to train an intermediate-size model (data size in the 20 MB range, model complexity like an LSTM) with Trax, could a single computer with a 6-core i5 CPU handle it? How do we determine the required system spec? Many thanks.
I’m no expert on this, so what I usually do is run a few test runs and check the memory footprint and the time to complete an epoch, to get a rough estimate of how long the model will train (see the sketch below). That is a pretty lame “strategy”, but it is what it is…
From what I understand from your post, you shouldn’t have problems fitting the model into memory (which is usually the biggest problem in general), but the training time could be quite long (slow training).
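Concretely, the probe looks something like this (a hypothetical sketch: training_loop is the Loop object from earlier in the thread, and both step counts are made-up numbers, not from this thread):

import time

n_probe_steps = 100          # short probe run (made-up number)
start = time.perf_counter()
training_loop.run(n_steps=n_probe_steps)
per_step = (time.perf_counter() - start) / n_probe_steps

target_steps = 50_000        # hypothetical full training budget
print(f"~{per_step:.2f} s/step, so ~{per_step * target_steps / 3600:.1f} h total")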
I’m not sure if you have access to the Generative AI course lecture “Computational challenges of training LLMs” (if you don’t, it doesn’t matter; the main point is discussed here, and by the way I also think the calculations there are off, but a follow-up might reveal the truth).
Again, I’m no expert, but in my experience small/intermediate models (fewer than 300M parameters) are not a problem to train on GPUs, whereas on a CPU I think it would take ages.
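For the memory side, a back-of-the-envelope calculation along the lines of that lecture (the parameter count and the optimizer-state accounting here are assumptions for illustration, not numbers from this thread):

n_params = 10_000_000      # hypothetical LSTM-sized model
bytes_per_param = 4        # float32 weights
# Adam keeps two extra float32 buffers (m and v) per parameter, and you
# also hold gradients, so roughly 4x the raw weight memory in total.
train_mem_gb = n_params * bytes_per_param * 4 / 1e9
print(f"~{train_mem_gb:.2f} GB for weights + gradients + Adam state")

That comes out well under 1 GB for a model this size, which matches the point above: memory is unlikely to be the bottleneck, training time is.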