Extensions#
Once you have a working baseline using precomputed features, frozen backbones, and a simple triplet loss function, you can start exploring various extensions and improvements to your sketch-photo retrieval system.
One option is to replace the loss function with more advanced alternatives. Some of them were mentioned earlier, such as NT-Xent loss.
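To make this concrete, NT-Xent treats each (sketch, photo) pair in a batch as a positive pair and every other photo in the batch as a negative. Below is a minimal NumPy sketch of the idea; in actual training you would typically implement it with your framework's cross-entropy on the similarity logits.

```python
import numpy as np

def nt_xent_loss(sketch_emb, photo_emb, temperature=0.07):
    """NT-Xent loss for a batch of paired sketch/photo embeddings.

    Row i of each array is one (sketch, photo) positive pair; every
    other photo in the batch serves as a negative for sketch i.
    """
    s = sketch_emb / np.linalg.norm(sketch_emb, axis=1, keepdims=True)
    p = photo_emb / np.linalg.norm(photo_emb, axis=1, keepdims=True)
    logits = s @ p.T / temperature                # (B, B) scaled cosine similarities
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    log_softmax = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_softmax))         # positives sit on the diagonal
```

The temperature controls how sharply the softmax focuses on the hardest negatives; 0.07 is a common starting value, but it is worth tuning.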
Another direction is to unfreeze the backbone and fine-tune the entire network end-to-end, allowing the feature extractor to adapt specifically to the sketch-photo domain.
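A common recipe for this is discriminative learning rates: unfreeze everything, but train the pretrained backbone with a much smaller learning rate than the freshly initialized embedding head. A minimal PyTorch sketch, where the tiny `backbone` and `head` networks are placeholders for your own models:

```python
import torch
from torch import nn

# Placeholder networks; substitute your pretrained backbone and embedding head.
backbone = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 256))
head = nn.Linear(256, 128)

# Unfreeze the backbone for end-to-end training...
for p in backbone.parameters():
    p.requires_grad = True

# ...but give the pretrained layers a much smaller learning rate than the
# new head, so the features adapt gradually instead of being destroyed.
optimizer = torch.optim.AdamW([
    {"params": backbone.parameters(), "lr": 1e-5},
    {"params": head.parameters(), "lr": 1e-3},
])
```

Warming up with the backbone frozen for a few epochs before unfreezing is another common variant of this recipe.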
You can also experiment with data augmentation, different backbone architectures, deeper embedding heads, or using category-balanced batch sampling to improve training stability and retrieval accuracy.
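As an illustration of category-balanced sampling, the sketch below draws a few classes per batch and a fixed number of samples per class, which guarantees that every anchor has both in-batch positives and negatives. The function and parameter names are illustrative, not from a library:

```python
import random
from collections import defaultdict

def balanced_batches(labels, classes_per_batch=8, samples_per_class=4,
                     num_batches=100, seed=0):
    """Yield index batches with a fixed number of samples per class.

    Each batch covers `classes_per_batch` distinct categories, so every
    anchor has both positives and negatives within the batch.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    eligible = [c for c, idxs in by_class.items()
                if len(idxs) >= samples_per_class]
    for _ in range(num_batches):
        batch = []
        for c in rng.sample(eligible, classes_per_batch):
            batch.extend(rng.sample(by_class[c], samples_per_class))
        yield batch
```

With PyTorch, a generator like this can back a custom `batch_sampler` for a `DataLoader`.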
Finally, you may consider shifting from category-level retrieval (finding photos from the same class) to instance-level retrieval, where the goal is to retrieve the exact photo a sketch was drawn from. This task is more challenging and may require several modifications to the training process. It may be a good idea to read some research papers on this topic, starting with the one that introduced the Sketchy dataset: The Sketchy Database: Learning to Retrieve Badly Drawn Bunnies.
As you explore these directions, keep in mind that a good evaluation setup is essential. Without reliable metrics and a consistent retrieval protocol, it becomes difficult to assess the impact of your changes.
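One simple, robust metric that works for both category-level and instance-level retrieval is Recall@K: the fraction of queries whose top-K results contain at least one correct photo. A minimal sketch (for instance-level retrieval, pass unique photo IDs instead of category labels):

```python
import numpy as np

def recall_at_k(similarities, query_labels, gallery_labels, k=10):
    """Fraction of queries with a correct gallery item in the top-k.

    `similarities` is a (num_queries, num_gallery) score matrix; the
    label lists give the ground-truth category (or photo ID) of each row
    and column.
    """
    topk = np.argsort(-similarities, axis=1)[:, :k]  # top-k gallery indices per query
    hits = 0
    for q, idxs in enumerate(topk):
        if any(gallery_labels[i] == query_labels[q] for i in idxs):
            hits += 1
    return hits / len(query_labels)
```

Reporting Recall@1, Recall@5, and Recall@10 on a fixed query/gallery split gives a consistent protocol for comparing your extensions.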
Sketch drawing#
To make the retrieval system more accessible and engaging, you can add an interactive sketch interface that allows users to draw a query sketch directly in the browser or notebook. Instead of selecting a pre-existing sketch from the dataset, the user draws their own sketch, and the system immediately returns the most visually similar photos from the gallery. This interface is useful for demos, testing, and exploring how the model handles different drawing styles and levels of abstraction. It also provides a direct, intuitive way to evaluate retrieval quality beyond scripted queries.
Implementation#
Copy the following code into a new file called widget.py. You must also install the anywidget package (for example, with pip install anywidget).
import anywidget
import traitlets
import numpy as np


class ByteBuffer(traitlets.TraitType):
    """Accept bytes *or* memoryview; always store as bytes."""

    info_text = "bytes or memoryview"
    default_value = b""

    def validate(self, obj, value):
        if isinstance(value, (bytes, bytearray)):
            return bytes(value)     # ensure immutable bytes
        if isinstance(value, memoryview):
            return value.tobytes()  # convert once, then store
        self.error(obj, value)      # otherwise raise


class SketchCanvas(anywidget.AnyWidget):
    """Drawing canvas that returns a 256×256×3 RGB NumPy array (uint8)."""

    _sketch = ByteBuffer().tag(sync=True)

    def add_listener(self, callback):
        """Add a callback to be called when the sketch is updated."""
        self.observe(callback, "_sketch")

    def to_array(self):
        """Convert the sketch to a NumPy array."""
        arr = np.frombuffer(self._sketch, dtype=np.uint8)
        arr = arr.reshape((256, 256, 4))[:, :, :3]  # drop the alpha channel
        return arr

    _esm = """
    // Front-end code (executed inside the widget area)
    export function render({ model, el }) {
      const SIZE = 256;
      el.innerHTML = `
        <canvas id="draw" width="${SIZE}" height="${SIZE}"
                style="border:1px solid #666;border-radius:8px;"></canvas>
        <div style="margin-top:8px">
          <button id="clear">Clear</button>
          <button id="send">Send</button>
        </div>`;

      const can = el.querySelector("#draw");
      const ctx = can.getContext("2d");
      ctx.fillStyle = "white"; ctx.fillRect(0, 0, SIZE, SIZE);
      ctx.lineCap = ctx.lineJoin = "round";
      ctx.lineWidth = 5; ctx.strokeStyle = "black";

      let drawing = false;
      const pos = e => {
        const r = can.getBoundingClientRect();
        return {x: e.clientX - r.left, y: e.clientY - r.top};
      };

      can.addEventListener("pointerdown", e => {
        drawing = true; const {x, y} = pos(e); ctx.beginPath(); ctx.moveTo(x, y);
      });
      can.addEventListener("pointermove", e => {
        if (!drawing) return; const {x, y} = pos(e); ctx.lineTo(x, y); ctx.stroke();
      });
      ["pointerup", "pointerleave", "pointercancel"]
        .forEach(ev => can.addEventListener(ev, () => drawing = false));

      // Clear: repaint the canvas white
      el.querySelector("#clear").onclick = () => {
        ctx.clearRect(0, 0, SIZE, SIZE);
        ctx.fillStyle = "white"; ctx.fillRect(0, 0, SIZE, SIZE);
      };

      // Send: ship the raw RGBA pixels to Python
      el.querySelector("#send").onclick = () => {
        const rgba = ctx.getImageData(0, 0, SIZE, SIZE).data;
        const dv = new DataView(rgba.buffer);
        model.set('_sketch', dv);  // sync to Python
        model.save_changes();
      };
    }
    """
Usage#
The following code sets up an interactive tool that allows users to draw a sketch and instantly see retrieval results. It uses the SketchCanvas widget to capture user input. When the user clicks Send, the on_canvas_drawn function is triggered; it clears the display and converts the sketch to a NumPy array. You should then add your own code to preprocess the sketch, extract features with the model, and perform retrieval.
from IPython.display import display

from widget import SketchCanvas

canvas = SketchCanvas()

def on_canvas_drawn(change):
    display(clear=True)  # clear previously shown results
    sketch = canvas.to_array()
    # TODO:
    # - Extract features from the sketch
    # - Compute the sketch embedding
    # - Compute the similarities with the photo gallery
    # - Visualize the retrieval results

canvas.add_listener(on_canvas_drawn)
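As a sketch of what the retrieval step in the TODO might look like, assuming you already hold embeddings for the gallery photos: the helper below ranks the gallery by cosine similarity to a sketch embedding. The function and variable names are illustrative, not part of the widget.

```python
import numpy as np

def retrieve(sketch_embedding, gallery_embeddings, top_k=5):
    """Rank gallery photos by cosine similarity to the sketch embedding."""
    sketch = sketch_embedding / np.linalg.norm(sketch_embedding)
    gallery = gallery_embeddings / np.linalg.norm(
        gallery_embeddings, axis=1, keepdims=True)
    scores = gallery @ sketch             # cosine similarity per gallery photo
    ranking = np.argsort(-scores)[:top_k]  # best matches first
    return ranking, scores[ranking]
```

Inside on_canvas_drawn you would feed the preprocessed sketch through your model to get sketch_embedding, call retrieve against your precomputed gallery embeddings, and display the photos at the returned indices.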
In another cell, you should display the canvas widget as follows.
display(canvas)