Introduction to the Segment Anything iOS App

Segment Anything is an image segmentation app for iPhone and iPad, built on the open-source SAM (Segment Anything Model). All processing runs locally on the device, so no network connection is required, and the app has been optimized for smooth, reliable on-device performance.

Exporting to ONNX Format

SAM relies on two models: the image encoder (a ViT), which extracts image features, and the mask decoder, which produces the final segmentation mask.

Export Image Encoder to ONNX

python
import numpy as np
import torch
from segment_anything import sam_model_registry

sam_checkpoint = "./model/sam_vit_b_01ec64.pth"
model_type = "vit_b"

device = "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Load input data (a preprocessed tensor saved as .npy)
input_data = np.load("./model/image_encoder/input.npy")
input_tensor = torch.from_numpy(input_data).float()

# Export the image encoder model using torch.onnx
torch.onnx.export(
    sam.image_encoder,                      # model instance
    input_tensor,                           # input tensor
    "./model/image_encoder/image_encoder_vit.onnx",  # output ONNX file name
    verbose=True,
    export_params=True,                     # whether to export model parameters
    # opset_version=11,                     # ONNX opset version
    do_constant_folding=True,               # whether to perform constant folding optimization
    input_names=["input"],                  # input node names
    output_names=["output"],                # output node names
)
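
Before moving on to the decoder, it can be worth checking that the exported encoder agrees with the PyTorch model. This verification step is not part of the original export script; it is a minimal sketch that assumes the onnxruntime package is installed:

python
import numpy as np
import onnxruntime as ort

# Reference output from the PyTorch image encoder
with torch.no_grad():
    torch_out = sam.image_encoder(input_tensor).numpy()

# Output from the exported ONNX model
session = ort.InferenceSession("./model/image_encoder/image_encoder_vit.onnx")
onnx_out = session.run(["output"], {"input": input_data.astype(np.float32)})[0]

# The two results should agree within floating-point tolerance
print("max abs diff:", np.abs(torch_out - onnx_out).max())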

Export Mask Decoder to ONNX

To simplify the ONNX export, we removed a few inputs: mask_input, has_mask_input, and orig_im_size. orig_im_size now defaults to 1024x1024, which greatly simplifies the handling of variable-length dimensions in ONNX and makes it easier to port the model to iOS later. We do not need mask_input for now, so it is removed as well.

python
# Export mask decoder
import warnings

# SamOnnxModel is the decoder wrapper from segment_anything.utils.onnx,
# modified as described above to drop mask_input, has_mask_input, and orig_im_size
from segment_anything.utils.onnx import SamOnnxModel

onnx_model = SamOnnxModel(sam, return_single_mask=False, return_extra_metrics=True)

dynamic_axes = {
    # "point_coords": {1: "num_points"},
    # "point_labels": {1: "num_points"},
}

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
batch_size = 1
num_points = 4
dummy_inputs = {
    "image_embeddings": torch.randn(batch_size, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(batch_size, num_points, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(batch_size, num_points), dtype=torch.float),
    # "mask_input": torch.randn(batch_size, 1, *mask_input_size, dtype=torch.float),
    # "has_mask_input": torch.tensor([batch_size], dtype=torch.float),
    # "orig_im_size": torch.tensor([1024, 1024], dtype=torch.float),
}
output_names = [
    "masks",
    "scores",
    "stability_scores",
    "areas",
    "low_res_masks",
]

onnx_model_path = "./model/mask_decoder/mask_decoder." + str(num_points) + ".onnx"

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
    warnings.filterwarnings("ignore", category=UserWarning)
    with open(onnx_model_path, "wb") as f:
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            f,
            export_params=True,
            verbose=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

print("model exported")
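
With both files exported, the pair can be exercised end to end on the desktop before touching iOS. The original post does not show this step; the following is a rough sketch using onnxruntime, reusing input_data from the encoder script and feeding the resulting embedding plus four dummy point prompts into the decoder:

python
import numpy as np
import onnxruntime as ort

# Image embedding from the exported encoder
encoder = ort.InferenceSession("./model/image_encoder/image_encoder_vit.onnx")
embedding = encoder.run(["output"], {"input": input_data.astype(np.float32)})[0]

# Four point prompts in 1024x1024 image coordinates; label 1 marks foreground clicks
point_coords = np.array([[[512, 512], [400, 300], [600, 700], [100, 900]]], dtype=np.float32)
point_labels = np.array([[1, 1, 1, 1]], dtype=np.float32)

decoder = ort.InferenceSession("./model/mask_decoder/mask_decoder.4.onnx")
masks, scores, stability_scores, areas, low_res_masks = decoder.run(
    None,
    {
        "image_embeddings": embedding,
        "point_coords": point_coords,
        "point_labels": point_labels,
    },
)
print(masks.shape, scores)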

Running the ONNX Model on iOS Devices

Utilizing MPSX Library

To run the ONNX model on iOS, we leveraged a closed-source modification of the MPSX library. MPSX is an excellent open-source project that allows you to load ONNX models on iOS using Swift and perform inference in a straightforward manner.

Enhancements to MPSX

We made extensive enhancements to the MPSX library to support a more comprehensive set of ONNX operators and offer a more flexible way to invoke the model. These modifications enabled us to integrate the ONNX model seamlessly into our iOS application.
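
One detail worth noting before the code: the Swift snippet below loads image_encoder_vit.sim.f16.onnx rather than the file exported above, i.e. a simplified, float16 variant of the encoder. The exact conversion pipeline is not described in this post; one plausible way to produce such a file, assuming the onnx-simplifier and onnxconverter-common packages, is sketched here:

python
import onnx
from onnxsim import simplify
from onnxconverter_common import float16

# Simplify the exported graph (constant folding, shape inference, etc.)
model = onnx.load("./model/image_encoder/image_encoder_vit.onnx")
model_sim, ok = simplify(model)
assert ok, "onnx-simplifier could not validate the simplified model"

# Convert weights and activations to float16 for a smaller, faster on-device model
model_f16 = float16.convert_float_to_float16(model_sim)
onnx.save(model_f16, "./model/image_encoder/image_encoder_vit.sim.f16.onnx")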

Code Snippet: Loading and Running the ONNX Model

Below is a Swift code snippet that demonstrates how to load the ONNX model and perform inference:

swift
let graph = buildGraphVit(path: folder + "image_encoder_vit.sim.f16.onnx",
                          floatPrecision: .float16,
                          input: "input",
                          output: "output",
                          inputShape: Shape([1, 3, 1024, 1024]))

let input = Tensor.loadFromNpy(path: folder + "input.npy")!
let output = graph.forward(inputs: ["input": input], outputs: ["output"])["output"]!
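
The input.npy consumed here is the same preprocessed 1x3x1024x1024 tensor used by the export script earlier. The post does not show how it is produced; a sketch using SAM's own preprocessing utilities (resize the longest side to 1024, normalize, pad to a square) might look like the following, assuming OpenCV is available and sam is the model loaded during export:

python
import cv2
import numpy as np
import torch
from segment_anything.utils.transforms import ResizeLongestSide

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)

# Resize so the longest side matches the encoder's input resolution (1024 for vit_b)
transform = ResizeLongestSide(sam.image_encoder.img_size)
resized = transform.apply_image(image)

# HWC uint8 -> NCHW float, then normalize and pad to 1024x1024 with sam.preprocess
x = torch.as_tensor(resized, dtype=torch.float).permute(2, 0, 1)[None]
x = sam.preprocess(x)

np.save("./model/image_encoder/input.npy", x.cpu().numpy())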

Interesting Tidbits

It’s worth noting that we encountered the following error when running inference on the ViT model on iOS 17:

bash
Input N1D1C133H133W128 and output N1D19C19H7W128 tensors must have the same number of elements

The root cause of this issue lies in the /layers.1/blocks.0/reshape layer. The input tensor shape is [1, 133, 133, 128], and the output tensor shape is [1, 19, 7, 19, 7, 128]. This Reshape operation was not a problem on iOS 16 and earlier versions but throws an error on iOS 17.

After much deliberation, we found a workaround: we force the output shape from a 6-dimensional tensor down to a 5-dimensional one, [19, 7, 19, 7, 128]. Since the batch size is always 1 in our application, this does not change the semantics, and it successfully bypasses the error.
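
To see why dropping the leading batch dimension is safe here, note that 133 = 19 x 7, so both shapes describe the same 1 x 133 x 133 x 128 buffer. A quick NumPy check (our illustration, not part of the app):

python
import numpy as np

x = np.random.rand(1, 133, 133, 128).astype(np.float32)

# The original 6-D reshape that trips up iOS 17...
six_d = x.reshape(1, 19, 7, 19, 7, 128)
# ...and the 5-D workaround, valid because the batch size is always 1
five_d = x.reshape(19, 7, 19, 7, 128)

# Identical contents, just without the leading singleton dimension
assert np.array_equal(six_d[0], five_d)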