Multi-task learning (MTL) is a mechanism for learning multiple tasks in parallel with a shared model. The goal of MTL is to improve generalization by exploiting the domain-specific knowledge contained in the training samples of related tasks. Compared with learning each task independently, an MTL approach can learn the tasks more quickly and with less data, while reducing overfitting. However, there are design challenges involved in multi-task learning. Different tasks may have conflicting needs: trying to improve the performance of one task may harm the performance of another, a phenomenon known as negative transfer. Combining the task losses is another challenge. This post gives a general idea of the various multi-task learning architectures available.
In the context of deep learning, multi-task learning is commonly done through hard parameter sharing, an architecture design in which model weights are shared between tasks so that a single set of weights is trained to minimize multiple loss functions. Soft parameter sharing instead keeps an individual task-specific model with separate weights for each task, and the distance between the model parameters of different tasks is added to the joint objective function.
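A minimal PyTorch sketch of the two schemes is shown below (the layer sizes, two-task setup, and the soft_sharing_penalty helper are illustrative assumptions, not taken from the survey): hard sharing reuses one trunk for every task head, while soft sharing keeps separate towers and penalizes the distance between their weights.

```python
import torch.nn as nn

class HardSharingModel(nn.Module):
    """One shared trunk, one output head per task (hard parameter sharing)."""
    def __init__(self, in_dim=32, hidden=64, num_tasks=2):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.heads = nn.ModuleList([nn.Linear(hidden, 1) for _ in range(num_tasks)])

    def forward(self, x):
        shared = self.trunk(x)               # single set of shared weights
        return [head(shared) for head in self.heads]

def soft_sharing_penalty(model_a, model_b):
    """Soft parameter sharing: L2 distance between corresponding parameters of
    two task-specific models, to be added to the joint loss."""
    return sum((pa - pb).pow(2).sum()
               for pa, pb in zip(model_a.parameters(), model_b.parameters()))
```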
Multi-task Architectures
These architectures focus on which portion of the model's parameters is shared between tasks, and on how task-specific and shared modules are parameterized and combined. Sharing too much across tasks may lead to negative transfer and worse performance, while sharing too little may lead to overfitting. The best architectures are those that balance the two.
Shared trunk:
Most multi-task architectures for computer vision follow a common template: a series of convolutional layers shared between all tasks acts as a base feature extractor, followed by task-specific output heads that take the extracted features as input.
The Tasks-Constrained Deep Convolutional Network (TCDCN) is an example of this traditional shared-trunk architecture.
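The template looks roughly like the following PyTorch sketch (the layer sizes and number of tasks are hypothetical, and TCDCN's actual trunk and heads are more elaborate):

```python
import torch.nn as nn

class SharedTrunkNet(nn.Module):
    """Shared convolutional feature extractor (the trunk) with one head per task."""
    def __init__(self, num_tasks=3, num_outputs=10):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.heads = nn.ModuleList(
            [nn.Linear(32, num_outputs) for _ in range(num_tasks)])

    def forward(self, x):                      # x: (B, 3, H, W)
        features = self.trunk(x)               # shared representation, (B, 32)
        return [head(features) for head in self.heads]
```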
Multi-task Network Cascades use a similar architecture to TCDCN, with a small difference in the task-specific layers: the output of each task-specific branch is fed as input to the next task, and so on, forming a cascade.
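The cascade wiring pattern can be sketched as below; only the heads are shown, and the real method operates on region proposals and masks rather than flat feature vectors, so this is just the connection pattern:

```python
import torch
import torch.nn as nn

class CascadedHeads(nn.Module):
    """Task heads in a cascade: each head sees the shared features plus the
    previous head's prediction."""
    def __init__(self, feat_dim=32, num_tasks=3):
        super().__init__()
        self.heads = nn.ModuleList(
            [nn.Linear(feat_dim if i == 0 else feat_dim + 1, 1)
             for i in range(num_tasks)])

    def forward(self, features):               # features: (B, feat_dim)
        preds, prev = [], None
        for i, head in enumerate(self.heads):
            inp = features if i == 0 else torch.cat([features, prev], dim=1)
            prev = head(inp)                    # prediction feeds the next task
            preds.append(prev)
        return preds
```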
Multi-gate Mixture-of-Experts (MMoE) explicitly learns to model task relationships from the data. The architecture shares a mixture of expert sub-models across all tasks, while a separate gating network per task learns how to combine the experts for that task.
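A compact PyTorch sketch of the expert/gate structure follows (the dimensions and single-output heads are illustrative assumptions):

```python
import torch
import torch.nn as nn

class MMoE(nn.Module):
    """Experts are shared across tasks; each task has its own gating network."""
    def __init__(self, in_dim=32, expert_dim=64, num_experts=4, num_tasks=2):
        super().__init__()
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(in_dim, expert_dim), nn.ReLU())
             for _ in range(num_experts)])
        self.gates = nn.ModuleList(
            [nn.Linear(in_dim, num_experts) for _ in range(num_tasks)])
        self.heads = nn.ModuleList(
            [nn.Linear(expert_dim, 1) for _ in range(num_tasks)])

    def forward(self, x):
        expert_out = torch.stack([e(x) for e in self.experts], dim=1)  # (B, E, D)
        outputs = []
        for gate, head in zip(self.gates, self.heads):
            w = torch.softmax(gate(x), dim=-1).unsqueeze(-1)           # (B, E, 1)
            mixed = (w * expert_out).sum(dim=1)                        # (B, D)
            outputs.append(head(mixed))
        return outputs
```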
Cross-Talk:
A cross-talk architecture has a separate network for each task, with information flowing between parallel layers of the task networks.
In a cross-stitch network, each task has its own individual network. The input to each layer is a linear combination of the outputs that every task network produced at the previous layer, and the combination weights are learned and task-specific.
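A two-task cross-stitch unit can be sketched as follows (the initialization values are arbitrary; the idea is one learned 2x2 mixing matrix per unit):

```python
import torch
import torch.nn as nn

class CrossStitchUnit(nn.Module):
    """Learned 2x2 mixing of the activations of two task networks."""
    def __init__(self):
        super().__init__()
        # alpha[i][j]: how much of task j's activation flows into task i's next layer
        self.alpha = nn.Parameter(torch.tensor([[0.9, 0.1],
                                                [0.1, 0.9]]))

    def forward(self, x_a, x_b):
        out_a = self.alpha[0, 0] * x_a + self.alpha[0, 1] * x_b
        out_b = self.alpha[1, 0] * x_a + self.alpha[1, 1] * x_b
        return out_a, out_b
```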
Neural Discriminative Dimensionality Reduction (NDDR-CNN) was introduced as a generalization of the cross-stitch network for feature fusion: the outputs of corresponding layers of each task network are concatenated along the channel dimension and passed through a 1x1 convolution to fuse the features.
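A minimal version of such a fusion layer might look like this (the channel counts and two-task setting are assumptions):

```python
import torch
import torch.nn as nn

class NDDRLayer(nn.Module):
    """Concatenate per-task feature maps along channels, then fuse with 1x1 convs
    to produce a new feature map for each task."""
    def __init__(self, channels=64, num_tasks=2):
        super().__init__()
        self.fuse = nn.ModuleList(
            [nn.Conv2d(channels * num_tasks, channels, kernel_size=1)
             for _ in range(num_tasks)])

    def forward(self, feats):                 # feats: list of (B, C, H, W) tensors
        stacked = torch.cat(feats, dim=1)     # (B, C * num_tasks, H, W)
        return [conv(stacked) for conv in self.fuse]
```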
Prediction Distillation:
Prediction distillation techniques are built on a core idea of multi-task learning: the predictions made for one task may be helpful in performing another, related task.
PAD-Net is a prediction distillation architecture in which preliminary predictions are first made for each task; these predictions are then combined by a multi-modal distillation network to produce refined final outputs.
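The two-stage pattern can be sketched as below (PAD-Net itself makes dense pixel-wise predictions and uses attention-guided message passing; this simplified version just concatenates the preliminary predictions with the shared features before refining):

```python
import torch
import torch.nn as nn

class PredictionDistillation(nn.Module):
    """Stage 1: preliminary per-task predictions. Stage 2: refine each task's
    output using the shared features plus all preliminary predictions."""
    def __init__(self, feat_dim=64, num_tasks=2):
        super().__init__()
        self.prelim_heads = nn.ModuleList(
            [nn.Linear(feat_dim, 1) for _ in range(num_tasks)])
        self.refine_heads = nn.ModuleList(
            [nn.Linear(feat_dim + num_tasks, 1) for _ in range(num_tasks)])

    def forward(self, features):                            # (B, feat_dim)
        prelim = [h(features) for h in self.prelim_heads]    # stage 1
        joined = torch.cat([features] + prelim, dim=1)       # cross-task signal
        final = [h(joined) for h in self.refine_heads]       # stage 2
        return prelim, final
```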
NLP Architecture, BERT for MTL:
Recent developments in neural architectures for NLP have brought several changes: traditional feed-forward architectures evolved into recurrent models, and recurrent models have in turn been replaced by attention-based architectures. These advances are also used for multi-task learning.
The MT-BERT model has shared layers based on BERT. In the shared layers, the input sequence is converted into a sequence of embedding vectors, and the attention mechanism gathers contextual information so that BERT encodes each token as a context-aware vector. On top of the shared BERT layers, a fully connected task-specific layer is added for each task.
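Assuming the Hugging Face transformers library, a sketch of this shared-encoder-plus-heads design might look as follows (the task_num_labels mapping and the use of the [CLS] vector for every task are simplifying assumptions):

```python
import torch.nn as nn
from transformers import AutoModel

class MTBert(nn.Module):
    """Shared BERT encoder with one fully connected classification head per task."""
    def __init__(self, task_num_labels, model_name="bert-base-uncased"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)   # shared layers
        hidden = self.encoder.config.hidden_size
        self.heads = nn.ModuleDict(
            {task: nn.Linear(hidden, n) for task, n in task_num_labels.items()})

    def forward(self, task, input_ids, attention_mask):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls_vec = out.last_hidden_state[:, 0]   # [CLS] token representation
        return self.heads[task](cls_vec)        # task-specific layer on top of BERT
```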
Multi-Modal Architecture:
So far, we have discussed multi-task architectures designed to handle data from a single, fixed domain. In multi-task multi-modal learning, representations are shared both across tasks and across modalities, so the learned representations must generalize at a further level of abstraction.
OmniNet is a unified multi-modal architecture that enables learning tasks with multiple input domains and supports generic multi-tasking for any set of tasks. OmniNet consists of multiple peripheral networks connected to a common Central Neural Processor (CNP). Each peripheral network encodes domain-specific input into a feature representation, and the output of the CNP is passed to several task-specific output heads.
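The peripheral/central split can be caricatured as follows; this is not the actual OmniNet CNP (which has its own encoder-decoder structure), only a stand-in showing how per-modality encoders feed one shared processor and several task heads, with all dimensions chosen arbitrarily:

```python
import torch
import torch.nn as nn

class PeripheralCentralSketch(nn.Module):
    """Per-modality 'peripheral' encoders map inputs to a common feature space;
    a shared Transformer encoder stands in for the Central Neural Processor."""
    def __init__(self, d_model=64, num_tasks=2):
        super().__init__()
        self.peripherals = nn.ModuleDict({
            "image": nn.Linear(512, d_model),   # e.g. pooled CNN features
            "text":  nn.Linear(300, d_model),   # e.g. word embeddings
        })
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4,
                                           batch_first=True)
        self.cnp = nn.TransformerEncoder(layer, num_layers=2)
        self.heads = nn.ModuleList([nn.Linear(d_model, 1) for _ in range(num_tasks)])

    def forward(self, inputs):                  # inputs: {modality: (B, T, dim)}
        tokens = torch.cat(
            [self.peripherals[m](x) for m, x in inputs.items()], dim=1)
        encoded = self.cnp(tokens).mean(dim=1)  # pool over the joint sequence
        return [head(encoded) for head in self.heads]
```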
Learned Architectures:
Another approach to multi-task architecture design is to learn the architecture itself along with the weights of the resulting model. Similar tasks then end up sharing more than unrelated tasks do, which can overcome the issue of negative transfer between tasks. This facilitates adaptive sharing between tasks at a level of precision that is not feasible with hand-designed shared architectures.
In a learned branching architecture, all tasks share every layer of the network at the beginning of training. As training goes on, less related tasks branch off into separate clusters, so that ultimately only highly related tasks share a large number of parameters.
Fine-grained parameter sharing allows more flexible information flow between tasks than sharing at the level of whole layers. Each task has a sparse subnetwork that is extracted using iterative magnitude pruning after training for a small number of epochs. The subnetworks extracted for the different tasks overlap, and that overlap constitutes fine-grained parameter sharing between the tasks.
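One round of the magnitude-pruning step could look like the sketch below (the sparsity level, the decision to skip one-dimensional parameters, and the per-tensor threshold are all illustrative choices, not the exact procedure from the survey):

```python
import torch

def magnitude_prune_masks(model, sparsity=0.8):
    """One pruning round: keep only the largest-magnitude weights of each tensor;
    the resulting boolean masks define the task's sparse subnetwork."""
    masks = {}
    for name, param in model.named_parameters():
        if param.dim() < 2:                    # skip biases and norm parameters
            continue
        threshold = torch.quantile(param.detach().abs().flatten(), sparsity)
        masks[name] = param.detach().abs() > threshold
    return masks

# The overlap of two tasks' masks is the set of parameters those tasks share:
# shared = {name: masks_a[name] & masks_b[name] for name in masks_a}
```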
Conclusion
This post reviews the architecture-design side of the multi-task learning field. The key techniques for multi-task neural networks are shared feature extractors followed by task-specific layers, learning what to share, fine-grained parameter sharing, varied parameter sharing, and the sharing and recombination of modules. Two more topics remain to be discussed: optimization techniques and task relationship learning.
Most of the content of this post is drawn from the paper “Multi-Task Learning with Deep Neural Networks: A Survey”. The author explains most of the architectures, optimization techniques, and task relationship learning methods in detail, but does not say much about state-of-the-art architectures for NLP. Comparing the performance of the models would have given a better picture of the different multi-task models.
As future research, task balancing still needs to be improved, and techniques must be developed that learn the underlying concepts shared between tasks and apply them to new, unfamiliar situations.
Source:
Multi-task Learning with Deep Neural Networks: A Survey
An Empirical Study of Multi-Task Learning on BERT for Biomedical Text Mining
OmniNet: A unified architecture for multi-modal multi-task learning