Ordinal output layers and loss functions (Rennie & Srebro, 2005) for PyTorch and TF/Keras
Project description
DeepOrdinal
Ordinal output layers and loss functions for PyTorch and TensorFlow/Keras, based on Rennie & Srebro (2005).
DeepOrdinal provides an OrdinalOutput layer that converts a learned logit into ordinal class probabilities via sorted thresholds, plus loss functions designed specifically for ordinal regression.
Installation
pip install deepordinal
Install with a specific backend:
pip install "deepordinal[torch]" # PyTorch
pip install "deepordinal[tf]" # TensorFlow/Keras
For development:
pip install -e ".[torch,tf]"
Quick Start
PyTorch
import torch
import torch.nn as nn
from deepordinal.torch import OrdinalOutput, ordinal_loss
model = nn.Sequential(
nn.Linear(8, 16),
nn.ReLU(),
OrdinalOutput(input_dim=16, output_dim=4),
)
ordinal_layer = model[2]
optimizer = torch.optim.Adam(model.parameters())
# Training step
h = model[1](model[0](x_batch)) # hidden features
logits = ordinal_layer.linear(h) # raw logit
thresholds = ordinal_layer.interior_thresholds
loss = ordinal_loss(logits, y_batch, thresholds)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Inference
probs = model(x_batch) # (batch, K) class probabilities
preds = probs.argmax(dim=1)
TensorFlow/Keras
import tensorflow as tf
from deepordinal.tf import OrdinalOutput, ordinal_loss
dense = tf.keras.layers.Dense(16, activation="relu")
ordinal = OrdinalOutput(output_dim=4)
# Training step
with tf.GradientTape() as tape:
h = dense(x_batch)
logits = tf.matmul(h, ordinal.kernel) + ordinal.bias
thresholds = ordinal.interior_thresholds[0] # shape is (1, K-1)
loss = ordinal_loss(logits, y_batch, thresholds)
grads = tape.gradient(loss, dense.trainable_variables + ordinal.trainable_variables)
# Inference
probs = ordinal(dense(x_batch)) # (batch, K) class probabilities
preds = tf.argmax(probs, axis=1)
API
OrdinalOutput Layer
Projects an input down to a single logit and converts it into K class probabilities using K-1 learned, sorted thresholds:
P(y = k | x) = sigmoid(t(k+1) - logit) - sigmoid(t(k) - logit)
where t(0) = -inf and t(K) = inf are fixed, and interior thresholds are initialized sorted.
| PyTorch | TensorFlow/Keras | |
|---|---|---|
| Import | from deepordinal.torch import OrdinalOutput |
from deepordinal.tf import OrdinalOutput |
| Constructor | OrdinalOutput(input_dim, output_dim) |
OrdinalOutput(output_dim) |
| Logit access | layer.linear(h) |
tf.matmul(h, layer.kernel) + layer.bias |
| Thresholds | layer.interior_thresholds — shape (K-1,) |
layer.interior_thresholds[0] — shape (1, K-1) |
ordinal_loss
Threshold-based ordinal loss from Rennie & Srebro (2005). Operates on raw logits and thresholds rather than probability output.
ordinal_loss(logits, targets, thresholds, construction="all", penalty="logistic")
Parameters:
- logits —
(batch,)or(batch, 1), raw predictor output - targets —
(batch,), integer labels in[0, K) - thresholds —
(K-1,), sorted interior thresholds - construction —
"all"(default) or"immediate" - penalty —
"logistic"(default),"hinge","smooth_hinge", or"modified_least_squares"
Returns: scalar mean loss over the batch.
Constructions
- All-threshold (default, eq 13) — penalizes violations of every threshold, weighted by direction. Bounds mean absolute error. Best performer in the paper's experiments.
- Immediate-threshold (eq 12) — only penalizes violations of the two thresholds bounding the correct class segment.
Penalty Functions
| Name | Formula | Reference |
|---|---|---|
"logistic" |
log(1 + exp(-z)) |
eq 9 |
"hinge" |
max(0, 1-z) |
eq 5 |
"smooth_hinge" |
0 if z>=1, (1-z)^2/2 if 0<z<1, 0.5-z if z<=0 |
eq 6 |
"modified_least_squares" |
0 if z>=1, (1-z)^2 if z<1 |
eq 7 |
The paper recommends all-threshold + logistic as the best-performing combination (the default).
ordistic_loss
Probabilistic generalization of logistic regression to K-class ordinal problems (Section 4).
ordistic_loss(logits, targets, means, log_priors=None)
Parameters:
- logits —
(batch,)or(batch, 1), raw predictor output - targets —
(batch,), integer labels in[0, K) - means —
(K,), class means (convention: mu_1=-1, mu_K=1; interior means learned) - log_priors —
(K,)orNone, optional log-prior terms
Returns: scalar mean negative log-likelihood over the batch.
Examples
Complete training loops with synthetic ordinal data:
examples/example_torch.ipynb— PyTorchexamples/example_tf.ipynb— TensorFlow/Keras
Testing
pip install -e ".[torch,tf]"
pytest -v
License
MIT
Citation
Rennie, J. D. M. & Srebro, N. (2005). Loss Functions for Preference Levels: Regression with Discrete Ordered Labels. Proceedings of the IJCAI Multidisciplinary Workshop on Advances in Preference Handling.
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file deepordinal-0.2.1.post1.tar.gz.
File metadata
- Download URL: deepordinal-0.2.1.post1.tar.gz
- Upload date:
- Size: 10.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.5
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
c9f44c3e504331e232ece14d6576ddb3ab8872b6bd8da2a19228b6bcbfcebf9c
|
|
| MD5 |
80e714820fbdf470507ad2e440d68b00
|
|
| BLAKE2b-256 |
de02ff4c59b74c68125a5ad9b83b3a213399dec06a16a8872d225d89fe578b7a
|
File details
Details for the file deepordinal-0.2.1.post1-py3-none-any.whl.
File metadata
- Download URL: deepordinal-0.2.1.post1-py3-none-any.whl
- Upload date:
- Size: 8.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.5
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
8a385fe6c781c5092da1bebcc43f5f5f2ba4ac6abdca11ba3eeb23bf408047ad
|
|
| MD5 |
df47f8d5808507d4b5c3d20ee6002295
|
|
| BLAKE2b-256 |
8bb7e90e98b844a2607f2ce08617c0309e72474c9dc8d7a715eb473acea87083
|