Getting started with training and sampling¶
In this guide, we'll step you through using the logits Python SDK to do the basic operations needed for training and sampling on the Logits platform.
Creating the training client¶
The main object we'll be using is the TrainingClient, which corresponds to a fine-tuned model that we can train and sample from.
First, set your Logits API key environment variable. In the terminal where you'll run Python, or in your .bashrc, put export LOGITS_API_KEY=<your key>.
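Before constructing any client, it can help to fail fast when the key is missing. The sketch below is a hypothetical helper (`require_api_key` is not part of the SDK); it assumes only that the key lives in the `LOGITS_API_KEY` environment variable, as described above.

```python
import os


def require_api_key(env=None, name="LOGITS_API_KEY"):
    """Return the API key from the environment, or raise a clear error.

    Hypothetical helper for illustration; not part of the logits SDK.
    """
    env = os.environ if env is None else env
    key = env.get(name, "").strip()
    if not key:
        raise RuntimeError(
            f"{name} is not set; run `export {name}=<your key>` first."
        )
    return key
```

Calling `require_api_key()` at the top of a script turns a missing key into an immediate, readable error instead of a failure deeper in the run.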
Then, create a ServiceClient. This lets you find out what base models are available to be fine-tuned.
import logits
service_client = logits.ServiceClient()
print("Available models:")
for item in service_client.get_server_capabilities().supported_models:
    print("- " + item.model_name)
- Qwen/Qwen3.5-4B
- Qwen/Qwen3.5-9B
We'll use Qwen/Qwen3.5-4B for these examples. See Available Models in Logits for the full list.
Now we can create the TrainingClient:
base_model = "Qwen/Qwen3.5-4B"
training_client = service_client.create_lora_training_client(
    base_model=base_model
)
Preparing the training data¶
Now we can do training updates on the model. This quickstart example won't show best practices for LLM fine-tuning; it's just an API demo. Check out Rendering, Supervised Fine-tuning, and the other Cookbook examples for guidance on how to use Logits in real applications.
In this example, we'll train a model to translate words into Pig Latin. The rules for Pig Latin are simple:
- If a word begins with a consonant, move the leading consonant (or consonant cluster) to the end and add "ay".
- If a word begins with a vowel, just add "way" to the end.
Here is an example of a completion we'd like the model to produce. The prompt ends with "Pig Latin:", and everything after it is the model's completion:
Pig Latin: ello-hay orld-way
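The rules above can be sketched as a small reference function, which is handy for generating more data or checking model output. This is one reading of the rules, not part of the SDK: it assumes the hyphenated style used in this guide, moves the whole leading consonant cluster, and treats "y" as a vowel when it is not the first letter.

```python
VOWELS = set("aeiou")


def pig_latin_word(word):
    """Translate one lowercase word using the hyphenated style of this guide."""
    if word[0] in VOWELS:
        return f"{word}-way"
    # Find the first vowel; treat "y" as a vowel when it is not the first letter.
    for i, ch in enumerate(word):
        if ch in VOWELS or (ch == "y" and i > 0):
            return f"{word[i:]}-{word[:i]}ay"
    # No vowel at all: leave the word in place and just append "ay".
    return f"{word}-ay"


def pig_latin(phrase):
    """Translate a whitespace-separated phrase word by word."""
    return " ".join(pig_latin_word(w) for w in phrase.lower().split())


print(pig_latin("hello world"))
print(pig_latin("donut shop"))
```

Hand-written examples may treat clusters differently; for instance, this function returns "it-splay" for "split".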
Let's create some training examples and convert them to the format expected by the backend.
# Create some training examples
examples = [
    {
        "input": "banana split",
        "output": "anana-bay plit-say"
    },
    {
        "input": "quantum physics",
        "output": "uantum-qay ysics-phay"
    },
    {
        "input": "donut shop",
        "output": "onut-day op-shay"
    },
    {
        "input": "pickle jar",
        "output": "ickle-pay ar-jay"
    },
    {
        "input": "space exploration",
        "output": "ace-spay exploration-way"
    },
    {
        "input": "rubber duck",
        "output": "ubber-ray uck-day"
    },
    {
        "input": "coding wizard",
        "output": "oding-cay izard-way"
    },
]
# Convert examples into the format expected by the training client
from logits import types
# Get the tokenizer from the training client
tokenizer = training_client.get_tokenizer()
def process_example(example: dict, tokenizer) -> types.Datum:
    # Format the input with Input/Output template
    # For most real use cases, you'll want to use a renderer / chat template
    # (see later docs), but here we'll keep it simple.
    prompt = f"English: {example['input']}\nPig Latin:"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    prompt_weights = [0] * len(prompt_tokens)
    # Add a space before the output string, and finish with double newline
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
    completion_weights = [1] * len(completion_tokens)
    tokens = prompt_tokens + completion_tokens
    weights = prompt_weights + completion_weights
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]  # We're predicting the next token, so targets need to be shifted.
    weights = weights[1:]
    # A datum is a single training example for the loss function.
    # It has model_input, which is the input sequence that'll be passed into the LLM, and
    # loss_fn_inputs, which is a dictionary of extra inputs used by the loss function.
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
    )
processed_examples = [process_example(ex, tokenizer) for ex in examples]
# Visualize the first example for debugging purposes
datum0 = processed_examples[0]
print(f"{'Input':<20} {'Target':<20} {'Weight':<10}")
print("-" * 50)
for inp, tgt, wgt in zip(
    datum0.model_input.to_ints(),
    datum0.loss_fn_inputs['target_tokens'].tolist(),
    datum0.loss_fn_inputs['weights'].tolist(),
):
    print(f"{repr(tokenizer.decode([inp])):<20} {repr(tokenizer.decode([tgt])):<20} {wgt:<10}")
The visualization of the first example is:
Input                Target               Weight
--------------------------------------------------
'English'            ':'                  0.0
':'                  ' banana'            0.0
' banana'            ' split'             0.0
' split'             '\n'                 0.0
'\n'                 'P'                  0.0
'P'                  'ig'                 0.0
'ig'                 ' Latin'             0.0
' Latin'             ':'                  0.0
':'                  ' an'                1.0
' an'                'ana'                1.0
'ana'                '-b'                 1.0
'-b'                 'ay'                 1.0
'ay'                 ' pl'                1.0
' pl'                'it'                 1.0
'it'                 '-s'                 1.0
'-s'                 'ay'                 1.0
'ay'                 '\n\n'               1.0
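The shift between inputs, targets, and weights is easy to get wrong, so here is a minimal sketch of the same slicing on toy data. The token ids are made-up integers standing in for tokenizer output; no SDK calls are involved.

```python
# Toy token ids standing in for tokenizer output (made-up values).
prompt_tokens = [101, 102, 103]   # e.g. "English: ...\nPig Latin:"
completion_tokens = [201, 202]    # e.g. " anana-bay ...\n\n"

tokens = prompt_tokens + completion_tokens
weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens)

# Position i of the input predicts token i + 1, so drop the last input
# token and the first target and weight.
input_tokens = tokens[:-1]
target_tokens = tokens[1:]
target_weights = weights[1:]

for inp, tgt, wgt in zip(input_tokens, target_tokens, target_weights):
    print(f"{inp:<8} {tgt:<8} {wgt}")
```

After the shift, the last prompt token is paired with the first completion token and carries weight 1, which is why the final ':' row in the table above is the first one with a nonzero weight. The number of weighted positions equals the number of completion tokens.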
Performing a training update¶
Now we can use this data to perform a training update. We'll do 6 updates on the same batch of data. (Note that this is not typically a good way to train!)
import numpy as np
for _ in range(6):
    fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy")
    optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
    # Wait for the results
    fwdbwd_result = fwdbwd_future.result()
    optim_result = optim_future.result()
    # fwdbwd_result contains the logprobs of all the tokens we put in. Now we can compute the weighted
    # average log loss per token.
    logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
    weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
    print(f"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}")
Note that forward_backward and optim_step return futures immediately; a future only acknowledges that the server has queued the task. To improve speed, we submit both operations first and only then wait for their results by calling result() on each future.
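The submit-then-wait pattern and the loss arithmetic can be tried without a server. The sketch below uses Python's standard `concurrent.futures` as a stand-in queue; `fake_forward_backward` and `fake_optim_step` are hypothetical placeholders with made-up return values, not SDK functions.

```python
from concurrent.futures import ThreadPoolExecutor


# Stand-ins for the remote calls (hypothetical; not the logits SDK).
def fake_forward_backward(batch):
    # Pretend the server returns one logprob per target token.
    return [-2.0, -1.0, -0.5, -0.25]


def fake_optim_step(learning_rate):
    return {"learning_rate": learning_rate}


weights = [0.0, 0.0, 1.0, 1.0]  # only the last two tokens count toward the loss

# A single worker processes tasks in submission order, like a server-side queue.
with ThreadPoolExecutor(max_workers=1) as server:
    # Submit both operations first ...
    fwdbwd_future = server.submit(fake_forward_backward, ["example"])
    optim_future = server.submit(fake_optim_step, 1e-4)
    # ... and only then block on the results.
    logprobs = fwdbwd_future.result()
    optim_result = optim_future.result()

# Weighted average negative log-likelihood per token.
loss = -sum(lp * w for lp, w in zip(logprobs, weights)) / sum(weights)
print(f"Loss per token: {loss:.4f}")
```

Because both tasks are queued before either result is awaited, the second can start as soon as the first finishes, with no extra round trip in between.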
Sampling from the model¶
Now we can test our model by sampling from it. In this case, we'll translate the phrase "coffee break" into Pig Latin.
# First, create a sampling client. We need to transfer weights
sampling_client = training_client.save_weights_and_get_sampling_client(name='pig-latin-model')
# Now, we can sample from the model.
prompt = types.ModelInput.from_ints(tokenizer.encode("English: coffee break\nPig Latin:"))
params = types.SamplingParams(max_tokens=20, temperature=0.0, stop=["\n"]) # Greedy sampling
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8)
result = future.result()
print("Responses:")
for i, seq in enumerate(result.sequences):
    print(f"{i}: {repr(tokenizer.decode(seq.tokens))}")
Since sampling is nondeterministic (sadly, even with temperature=0.0, due to batching), the output will be different each time. You should see something like this:
Responses:
0: ' offe-bay eak-bay\n\n'
1: ' offey-coy eak-bray\n\n'
2: ' offecay eakbray\n\n'
3: ' offeec-cay eak-brcay\n\n\n'
4: ' offecay akebay\n\n'
5: ' offee-Cay ake-bay\n\n\n'
6: ' offey-pay eak-bray\n\n'
7: ' offee – cay eak – bray\n\n'
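A quick way to judge samples like these is to compare them with a reference translation. The sketch below scores the responses listed above; the reference string is an assumption derived from the Pig Latin rules earlier in this guide, not something the platform returns.

```python
from collections import Counter

# Sampled strings copied from the example output above.
responses = [
    " offe-bay eak-bay\n\n",
    " offey-coy eak-bray\n\n",
    " offecay eakbray\n\n",
    " offeec-cay eak-brcay\n\n\n",
    " offecay akebay\n\n",
    " offee-Cay ake-bay\n\n\n",
    " offey-pay eak-bray\n\n",
    " offee – cay eak – bray\n\n",
]
# Assumed reference translation of "coffee break" (not given in the guide).
expected = "offee-cay eak-bray"

cleaned = [r.strip() for r in responses]
exact_matches = sum(c == expected for c in cleaned)

# Per-word hits: how many responses contain each target word exactly.
word_hits = Counter()
for c in cleaned:
    for word in expected.split():
        if word in c.split():
            word_hits[word] += 1

print(f"Exact matches: {exact_matches}/{len(cleaned)}")
print(dict(word_hits))
```

A check like this makes it easy to see that a handful of updates on seven examples teaches the output format well before it teaches the rule itself.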
Computing logprobs for a sequence¶
We can use the sampler to compute logprobs for a given sequence as well. This uses the prefill step and is returned as prompt logprobs.
import logits
prompt = types.ModelInput.from_ints(tokenizer.encode("How many r's are in the word strawberry?"))
sample_response = sampling_client.sample(
    prompt=prompt,
    num_samples=1,
    sampling_params=logits.SamplingParams(max_tokens=1),  # Must be at least 1 token, represents prefill step
    include_prompt_logprobs=True,
).result()
# example: [None, -9.54505, -1.64629, -8.81116, -3.50217, -8.25927, ...]
print(sample_response.prompt_logprobs)
The first logprob is None (corresponding to the first token), and subsequent entries are logprobs of each token in the prompt.
The sampling client also has a helper function that is equivalent to the call above:
sampling_client.compute_logprobs(prompt).result()
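Per-token logprobs are often reduced to a single score for the whole sequence. The sketch below does this in plain Python, using only the first few example values shown above (the real list is longer); it skips the leading None and reports the total logprob and the perplexity.

```python
import math

# Example values in the shape shown above; the first entry is None.
prompt_logprobs = [None, -9.54505, -1.64629, -8.81116, -3.50217, -8.25927]

# Drop the first position, which has no logprob.
scored = [lp for lp in prompt_logprobs if lp is not None]

total_logprob = sum(scored)              # log-likelihood of the scored tokens
avg_nll = -total_logprob / len(scored)   # average negative log-likelihood
perplexity = math.exp(avg_nll)

print(f"Total logprob: {total_logprob:.5f}")
print(f"Perplexity over {len(scored)} tokens: {perplexity:.1f}")
```

The total is useful for ranking candidate sequences of similar length, while perplexity normalizes by length and is easier to compare across prompts.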
Top-k logprobs¶
For distillation, it can be especially useful to also compute top-k logprobs for each token. These give you a sense of what the model "would have said" after each prefix, rather than only the logprob of the token that actually appears in the prompt.
import logits
sample_response = sampling_client.sample(
    prompt=prompt,
    num_samples=1,
    sampling_params=logits.SamplingParams(max_tokens=1),
    include_prompt_logprobs=True,
    topk_prompt_logprobs=5,
).result()
# example:
# [None,
# [(14924, -1.17005), (755, -2.23255), (2, -2.73255), (791, -3.67005), (16309, -4.29505)],
# [(25, -1.64629), (3137, -2.39629), (11630, -2.89629), (21460, -3.83379), (14881, -4.02129)],
# [(41, -3.49866), (42, -3.49866), (49, -4.24866), (38, -4.37366), (54, -4.49866)],
# [(311, -1.00217), (656, -2.25217), (2057, -2.75217), (649, -3.25217), (10470, -3.37717)],
# ...]
print(sample_response.topk_prompt_logprobs)
For each position in the prompt after the first (which is None), this returns a list of (token_id, logprob) pairs for the top-k most likely tokens at that position.
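As a rough illustration of how this structure can be consumed, the sketch below walks a list in the shape shown above and reports, per position, the most likely token id and how much probability mass the top-k entries cover. The data is copied from the first entries of the example output; nothing here calls the SDK.

```python
import math

# First few entries of the example output above: one list per prompt position.
topk_prompt_logprobs = [
    None,
    [(14924, -1.17005), (755, -2.23255), (2, -2.73255), (791, -3.67005), (16309, -4.29505)],
    [(25, -1.64629), (3137, -2.39629), (11630, -2.89629), (21460, -3.83379), (14881, -4.02129)],
]

summary = []
for position, candidates in enumerate(topk_prompt_logprobs):
    if candidates is None:
        continue  # the first token has no prefix to condition on
    best_id, best_lp = max(candidates, key=lambda pair: pair[1])
    covered = sum(math.exp(lp) for _, lp in candidates)
    summary.append((position, best_id, covered))
    print(f"pos {position}: top token {best_id}, p={math.exp(best_lp):.3f}, "
          f"top-k mass={covered:.3f}")
```

A low top-k mass means much of the distribution lies outside the returned candidates, which is worth knowing before using these values as distillation targets.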