# TimmWrapper

<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

## Overview

`TimmWrapper` is a helper class that lets you load timm models and use them with the Transformers library and its Auto classes.

```python
>>> import torch
>>> from PIL import Image
>>> from urllib.request import urlopen
>>> from transformers import AutoModelForImageClassification, AutoImageProcessor

>>> # Load image
>>> image = Image.open(urlopen(
...     'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
... ))

>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()

>>> # Preprocess image
>>> inputs = image_processor(image)

>>> # Forward pass
>>> with torch.no_grad():
...     logits = model(**inputs).logits

>>> # Get top 5 predictions
>>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
```
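
To turn the top-5 indices into readable class names, one option is to look them up in the `id2label` mapping stored on the config. The sketch below is a minimal illustration, not the library's own post-processing: it uses random stand-in logits and a made-up `id2label` dictionary so that it runs without downloading a checkpoint. With a real model you would use `model(**inputs).logits` and `model.config.id2label` instead.

```python
import torch

# Stand-in for `model(**inputs).logits`: one image, 10 hypothetical classes.
torch.manual_seed(0)
logits = torch.randn(1, 10)

# Hypothetical label map; with a real checkpoint this would be `model.config.id2label`.
id2label = {i: f"class_{i}" for i in range(10)}

# Same post-processing as above: softmax over the class dimension, scaled to percent.
top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)

# Pair each predicted index with its label and probability for the first image.
predictions = [
    (id2label[idx], round(prob, 2))
    for idx, prob in zip(top5_class_indices[0].tolist(), top5_probabilities[0].tolist())
]
for label, prob in predictions:
    print(f"{label}: {prob:.2f}%")
```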

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with TimmWrapper.

<PipelineTag pipeline="image-classification"/>

- [Collection of example notebooks](https://github.com/ariG23498/timm-wrapper-examples) 🌎

> [!TIP]
> For a more detailed overview, read the [official blog post](https://huggingface.co/blog/timm-transformers) on the timm integration.

## TimmWrapperConfig[[transformers.TimmWrapperConfig]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>class transformers.TimmWrapperConfig</name><anchor>transformers.TimmWrapperConfig</anchor><source>https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py#L31</source><parameters>[{"name": "architecture", "val": ": str = 'resnet50'"}, {"name": "initializer_range", "val": ": float = 0.02"}, {"name": "do_pooling", "val": ": bool = True"}, {"name": "model_args", "val": ": typing.Optional[dict[str, typing.Any]] = None"}, {"name": "**kwargs", "val": ""}]</parameters><paramsdesc>- **architecture** (`str`, *optional*, defaults to `"resnet50"`) --
  The timm architecture to load.
- **initializer_range** (`float`, *optional*, defaults to 0.02) --
  The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- **do_pooling** (`bool`, *optional*, defaults to `True`) --
  Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not.
- **model_args** (`dict[str, Any]`, *optional*) --
  Additional keyword arguments to pass to the `timm.create_model` function. e.g. `model_args={"depth": 3}`
  for `timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k` to create a model with 3 blocks. Defaults to `None`.</paramsdesc><paramgroups>0</paramgroups></docstring>

This is the configuration class to store the configuration for a timm backbone `TimmWrapper`.

It is used to instantiate a timm model according to the specified arguments, which define the model architecture.

Configuration objects inherit from [PretrainedConfig](/docs/transformers/v4.57.0/en/main_classes/configuration#transformers.PretrainedConfig) and can be used to control the model outputs. Read the
documentation from [PretrainedConfig](/docs/transformers/v4.57.0/en/main_classes/configuration#transformers.PretrainedConfig) for more information.

The config loads the ImageNet label descriptions and stores them in the `id2label` attribute. For default
ImageNet models, the `label2id` attribute is set to `None` because of collisions in the label descriptions.
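
One way to see why an inverse mapping can be unreliable, assuming the issue is that more than one class index can carry the same description: inverting such an `id2label` dictionary silently drops entries. The labels below are made up for illustration and are not taken from the actual ImageNet label file.

```python
# Made-up label descriptions: indices 1 and 3 share the same description.
id2label = {0: "tabby cat", 1: "crane", 2: "golden retriever", 3: "crane"}

# Naively inverting the mapping keeps only the last index seen for each description.
label2id = {label: idx for idx, label in id2label.items()}

print(label2id["crane"])             # 3, index 1 is no longer reachable
print(len(id2label), len(label2id))  # 4 3
```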



<ExampleCodeBlock anchor="transformers.TimmWrapperConfig.example">

Example:
```python
>>> from transformers import TimmWrapperModel

>>> # Initializing a timm model
>>> model = TimmWrapperModel.from_pretrained("timm/resnet18.a1_in1k")

>>> # Accessing the model configuration
>>> configuration = model.config
```

</ExampleCodeBlock>


</div>

## TimmWrapperImageProcessor[[transformers.TimmWrapperImageProcessor]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>class transformers.TimmWrapperImageProcessor</name><anchor>transformers.TimmWrapperImageProcessor</anchor><source>https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/timm_wrapper/image_processing_timm_wrapper.py#L39</source><parameters>[{"name": "pretrained_cfg", "val": ": dict"}, {"name": "architecture", "val": ": typing.Optional[str] = None"}, {"name": "**kwargs", "val": ""}]</parameters><paramsdesc>- **pretrained_cfg** (`dict[str, Any]`) --
  The configuration of the pretrained model used to resolve evaluation and
  training transforms.
- **architecture** (`Optional[str]`, *optional*) --
  Name of the architecture of the model.</paramsdesc><paramgroups>0</paramgroups></docstring>

Wrapper class that applies the preprocessing transforms of timm models within transformers.





<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>preprocess</name><anchor>transformers.TimmWrapperImageProcessor.preprocess</anchor><source>https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/timm_wrapper/image_processing_timm_wrapper.py#L97</source><parameters>[{"name": "images", "val": ": typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']]"}, {"name": "return_tensors", "val": ": typing.Union[str, transformers.utils.generic.TensorType, NoneType] = 'pt'"}]</parameters><paramsdesc>- **images** (`ImageInput`) --
  Image to preprocess. Expects a single image or a batch of images.
- **return_tensors** (`str` or `TensorType`, *optional*, defaults to `"pt"`) --
  The type of tensors to return.</paramsdesc><paramgroups>0</paramgroups></docstring>

Preprocess an image or batch of images.




</div></div>

## TimmWrapperModel[[transformers.TimmWrapperModel]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>class transformers.TimmWrapperModel</name><anchor>transformers.TimmWrapperModel</anchor><source>https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py#L154</source><parameters>[{"name": "config", "val": ": TimmWrapperConfig"}]</parameters></docstring>

Wrapper class for timm models to be used in transformers.



<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>forward</name><anchor>transformers.TimmWrapperModel.forward</anchor><source>https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py#L167</source><parameters>[{"name": "pixel_values", "val": ": FloatTensor"}, {"name": "output_attentions", "val": ": typing.Optional[bool] = None"}, {"name": "output_hidden_states", "val": ": typing.Union[bool, list[int], NoneType] = None"}, {"name": "return_dict", "val": ": typing.Optional[bool] = None"}, {"name": "do_pooling", "val": ": typing.Optional[bool] = None"}, {"name": "**kwargs", "val": ""}]</parameters><paramsdesc>- **pixel_values** (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`) --
  The tensors corresponding to the input images. Pixel values can be obtained using
  [TimmWrapperImageProcessor](/docs/transformers/v4.57.0/en/model_doc/timm_wrapper#transformers.TimmWrapperImageProcessor). See [TimmWrapperImageProcessor.preprocess()](/docs/transformers/v4.57.0/en/model_doc/timm_wrapper#transformers.TimmWrapperImageProcessor.preprocess) for details (`processor_class` uses
  [TimmWrapperImageProcessor](/docs/transformers/v4.57.0/en/model_doc/timm_wrapper#transformers.TimmWrapperImageProcessor) for processing images).
- **output_attentions** (`bool`, *optional*) --
  Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models.
- **output_hidden_states** (`bool`, *optional*) --
  Whether or not to return the hidden states of all layers. Not compatible with timm wrapped models.
- **return_dict** (`bool`, *optional*) --
  Whether or not to return a [ModelOutput](/docs/transformers/v4.57.0/en/main_classes/output#transformers.utils.ModelOutput) instead of a plain tuple.
- **do_pooling** (`bool`, *optional*) --
  Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. If `None` is passed, the
  `do_pooling` value from the config is used.</paramsdesc><paramgroups>0</paramgroups><rettype>`transformers.models.timm_wrapper.modeling_timm_wrapper.TimmWrapperModelOutput` or `tuple(torch.FloatTensor)`</rettype><retdesc>A `transformers.models.timm_wrapper.modeling_timm_wrapper.TimmWrapperModelOutput` or a tuple of
`torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
elements depending on the configuration ([TimmWrapperConfig](/docs/transformers/v4.57.0/en/model_doc/timm_wrapper#transformers.TimmWrapperConfig)) and inputs.

- **last_hidden_state** (`torch.FloatTensor`) -- The last hidden state of the model, output before applying the classification head.
- **pooler_output** (`torch.FloatTensor`, *optional*) -- The pooled output derived from the last hidden state, if applicable.
- **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True` is set or if `config.output_hidden_states=True`) -- A tuple containing the intermediate hidden states of the model at the output of each layer or specified layers.
- **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_attentions=True` is set or if `config.output_attentions=True`) -- A tuple containing the intermediate attention weights of the model at the output of each layer.
  Note: Currently, timm models do not support attentions output.</retdesc></docstring>
The [TimmWrapperModel](/docs/transformers/v4.57.0/en/model_doc/timm_wrapper#transformers.TimmWrapperModel) forward method overrides the `__call__` special method.

<Tip>

Although the recipe for the forward pass needs to be defined within this function, one should call the `Module`
instance afterwards instead of this function, since the former takes care of running the pre- and post-processing
steps while the latter silently ignores them.

</Tip>







<ExampleCodeBlock anchor="transformers.TimmWrapperModel.forward.example">

Examples:
```python
>>> import torch
>>> from PIL import Image
>>> from urllib.request import urlopen
>>> from transformers import AutoModel, AutoImageProcessor

>>> # Load image
>>> image = Image.open(urlopen(
...     'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
... ))

>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModel.from_pretrained(checkpoint).eval()

>>> # Preprocess image
>>> inputs = image_processor(image)

>>> # Forward pass
>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> # Get pooled output
>>> pooled_output = outputs.pooler_output

>>> # Get last hidden state
>>> last_hidden_state = outputs.last_hidden_state
```

</ExampleCodeBlock>
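
To make the relationship between `last_hidden_state` and `pooler_output` more concrete, here is a rough sketch for a CNN-style backbone. It assumes the pooled output is a global average over the spatial dimensions. That is only an illustrative assumption: the actual pooling is whatever the wrapped timm architecture applies, and the tensor shapes below are stand-ins rather than values read from a real checkpoint.

```python
import torch

# Stand-in for `outputs.last_hidden_state` of a CNN backbone:
# (batch_size, channels, height, width). Shapes are illustrative.
last_hidden_state = torch.randn(1, 2048, 7, 7)

# Global average pooling over the spatial dimensions, used here as an
# assumed stand-in for the pooling applied by the wrapped timm model.
pooled_output = last_hidden_state.mean(dim=(2, 3))

print(last_hidden_state.shape)  # torch.Size([1, 2048, 7, 7])
print(pooled_output.shape)      # torch.Size([1, 2048])
```

Passing `do_pooling=False` to the forward call (or setting it in the config) skips the pooling step, which can be useful when only the unpooled features are needed.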


</div></div>

## TimmWrapperForImageClassification[[transformers.TimmWrapperForImageClassification]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>class transformers.TimmWrapperForImageClassification</name><anchor>transformers.TimmWrapperForImageClassification</anchor><source>https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py#L269</source><parameters>[{"name": "config", "val": ": TimmWrapperConfig"}]</parameters></docstring>

Wrapper class for timm models to be used in transformers for image classification.



<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>forward</name><anchor>transformers.TimmWrapperForImageClassification.forward</anchor><source>https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py#L291</source><parameters>[{"name": "pixel_values", "val": ": FloatTensor"}, {"name": "labels", "val": ": typing.Optional[torch.LongTensor] = None"}, {"name": "output_attentions", "val": ": typing.Optional[bool] = None"}, {"name": "output_hidden_states", "val": ": typing.Union[bool, list[int], NoneType] = None"}, {"name": "return_dict", "val": ": typing.Optional[bool] = None"}, {"name": "**kwargs", "val": ""}]</parameters><paramsdesc>- **pixel_values** (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`) --
  The tensors corresponding to the input images. Pixel values can be obtained using
  [TimmWrapperImageProcessor](/docs/transformers/v4.57.0/en/model_doc/timm_wrapper#transformers.TimmWrapperImageProcessor). See [TimmWrapperImageProcessor.preprocess()](/docs/transformers/v4.57.0/en/model_doc/timm_wrapper#transformers.TimmWrapperImageProcessor.preprocess) for details (`processor_class` uses
  [TimmWrapperImageProcessor](/docs/transformers/v4.57.0/en/model_doc/timm_wrapper#transformers.TimmWrapperImageProcessor) for processing images).
- **labels** (`torch.LongTensor` of shape `(batch_size,)`, *optional*) --
  Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (mean-square loss); if
  `config.num_labels > 1`, a classification loss is computed (cross-entropy).
- **output_attentions** (`bool`, *optional*) --
  Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models.
- **output_hidden_states** (`bool`, *optional*) --
  Whether or not to return the hidden states of all layers. Not compatible with timm wrapped models.
- **return_dict** (`bool`, *optional*) --
  Whether or not to return a [ModelOutput](/docs/transformers/v4.57.0/en/main_classes/output#transformers.utils.ModelOutput) instead of a plain tuple.
- **kwargs** --
  Additional keyword arguments passed along to the `timm` model forward.</paramsdesc><paramgroups>0</paramgroups><rettype>[transformers.modeling_outputs.ImageClassifierOutput](/docs/transformers/v4.57.0/en/main_classes/output#transformers.modeling_outputs.ImageClassifierOutput) or `tuple(torch.FloatTensor)`</rettype><retdesc>A [transformers.modeling_outputs.ImageClassifierOutput](/docs/transformers/v4.57.0/en/main_classes/output#transformers.modeling_outputs.ImageClassifierOutput) or a tuple of
`torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
elements depending on the configuration ([TimmWrapperConfig](/docs/transformers/v4.57.0/en/model_doc/timm_wrapper#transformers.TimmWrapperConfig)) and inputs.

- **loss** (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) -- Classification (or regression if config.num_labels==1) loss.
- **logits** (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`) -- Classification (or regression if config.num_labels==1) scores (before SoftMax).
- **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) -- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
  (also called feature maps) of the model at the output of each stage.
- **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) -- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
  sequence_length)`.

  Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  heads.</retdesc></docstring>
The [TimmWrapperForImageClassification](/docs/transformers/v4.57.0/en/model_doc/timm_wrapper#transformers.TimmWrapperForImageClassification) forward method overrides the `__call__` special method.

<Tip>

Although the recipe for the forward pass needs to be defined within this function, one should call the `Module`
instance afterwards instead of this function, since the former takes care of running the pre- and post-processing
steps while the latter silently ignores them.

</Tip>







<ExampleCodeBlock anchor="transformers.TimmWrapperForImageClassification.forward.example">

Examples:
```python
>>> import torch
>>> from PIL import Image
>>> from urllib.request import urlopen
>>> from transformers import AutoModelForImageClassification, AutoImageProcessor

>>> # Load image
>>> image = Image.open(urlopen(
...     'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
... ))

>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()

>>> # Preprocess image
>>> inputs = image_processor(image)

>>> # Forward pass
>>> with torch.no_grad():
...     logits = model(**inputs).logits

>>> # Get top 5 predictions
>>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
```

</ExampleCodeBlock>
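
The `labels` argument selects between a regression and a classification loss depending on `config.num_labels`. The sketch below mirrors that description with plain PyTorch functions. `sketch_loss` is a hypothetical helper written for illustration and is not the library's implementation; the logits and labels are random stand-ins.

```python
import torch
import torch.nn.functional as F


def sketch_loss(logits: torch.Tensor, labels: torch.Tensor, num_labels: int) -> torch.Tensor:
    """Hypothetical helper mirroring the `labels` description; not the library's code."""
    if num_labels == 1:
        # Regression: mean-squared error between the single score and the target.
        return F.mse_loss(logits.squeeze(-1), labels.float())
    # Classification: cross-entropy over `num_labels` classes.
    return F.cross_entropy(logits.view(-1, num_labels), labels.view(-1))


torch.manual_seed(0)
# Four images, three classes: labels are class indices in [0, num_labels - 1].
cls_loss = sketch_loss(torch.randn(4, 3), torch.tensor([0, 2, 1, 1]), num_labels=3)
# Four images, one regression target each.
reg_loss = sketch_loss(torch.randn(4, 1), torch.tensor([0.5, 1.0, -0.3, 2.0]), num_labels=1)
print(cls_loss.item(), reg_loss.item())
```

With a real model, passing `labels` to the forward call returns the loss in the `loss` field of the output alongside `logits`.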


</div></div>
