where()

This function returns a new tensor whose elements are selected from one of two values, depending on a condition checked element-wise against existing tensors.

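Syntax: torch.where(condition, input, other)

where,

  • condition is a boolean tensor, usually produced by applying a relational operator to an existing tensor (for example, data > 45)
  • input is the value (a tensor or a scalar) taken at positions where condition is True
  • other is the value (a tensor or a scalar) taken at positions where condition is False

The three arguments are broadcast to a common shape, so a scalar such as 100 or 50 is applied at every position.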
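The second and third arguments of torch.where() can also be full tensors instead of scalars; each output element is then picked from one of them position by position. The minimal sketch below, which assumes a recent PyTorch release, replaces negative and zero entries with 0 and compares the result with a plain-Python loop that spells out the same element-wise rule. The names x, result and manual are illustrative only.

```python
import torch

# illustrative input containing negative, zero and positive values
x = torch.tensor([[-2, 5, 0],
                  [7, -1, 3]])

# keep x where it is positive, otherwise take the matching element of a zero tensor
result = torch.where(x > 0, x, torch.zeros_like(x))
print(result)

# the same rule written out element by element in plain Python
manual = [[v if v > 0 else 0 for v in row] for row in x.tolist()]
print(manual)

# both approaches agree
print(result.tolist() == manual)
```

Because zeros_like() copies the shape and dtype of x, the two candidate tensors line up exactly and no broadcasting is needed.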

Example: In this example, we use different relational operators (>, < and ==) to build the condition.

# import module
import torch

# create a 3-D tensor of shape (1, 3, 3):
# one block holding 3 rows of 3 elements each
data = torch.tensor([[[10, 20, 30],
                      [45, 67, 89],
                      [23, 45, 67]]])

# display
print(data)

# set the number 100 where the
# number is greater than 45,
# otherwise 50
print(torch.where(data > 45, 100, 50))

# set the number 100 where the
# number is less than 45,
# otherwise 50
print(torch.where(data < 45, 100, 50))

# set the number 100 where the number is
# equal to 23, otherwise 50
print(torch.where(data == 23, 100, 50))


Output:

tensor([[[10, 20, 30],
         [45, 67, 89],
         [23, 45, 67]]])
tensor([[[ 50,  50,  50],
         [ 50, 100, 100],
         [ 50,  50, 100]]])
tensor([[[100, 100, 100],
         [ 50,  50,  50],
         [100,  50,  50]]])
tensor([[[ 50,  50,  50],
         [ 50,  50,  50],
         [100,  50,  50]]])
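Conditions can also be combined before they are passed to torch.where(). The sketch below reuses the same data tensor and, like the example above, assumes a PyTorch release that accepts Python scalars for both values. The element-wise operators & (and) and | (or) join two relational checks; each comparison needs its own parentheses because & and | bind more tightly than > and < in Python. The names between and outside are illustrative.

```python
import torch

data = torch.tensor([[[10, 20, 30],
                      [45, 67, 89],
                      [23, 45, 67]]])

# 100 where 20 <= value <= 45, otherwise 50
between = torch.where((data >= 20) & (data <= 45), 100, 50)
print(between)

# 100 where value is below 20 or above 80, otherwise 50
outside = torch.where((data < 20) | (data > 80), 100, 50)
print(outside)
```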


Tensor Operations in PyTorch

In this article, we will discuss tensor operations in PyTorch.

PyTorch is a Python library for scientific computing and deep learning that performs operations on data stored in tensors. A tensor is a multi-dimensional collection of data, similar to a NumPy array. We can create a tensor using the tensor() function:

Syntax: torch.tensor([[[element1, element2, ..., element n], ..., [element1, element2, ..., element n]]])

where,

  • torch is the module
  • tensor is the function
  • elements are the data
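As a quick illustration of the syntax above, the sketch below builds a tensor with the same nesting (an outer list holding rows of elements) and inspects its shape, number of dimensions and element type. The values and the name t are arbitrary.

```python
import torch

# three levels of brackets -> a 3-D tensor: 1 block, 2 rows, 3 columns
t = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]]])

print(t.shape)   # torch.Size([1, 2, 3])
print(t.dim())   # 3
print(t.dtype)   # torch.int64 (Python integers default to 64-bit)
```

The number of bracket levels determines the number of dimensions, so dropping the outermost pair of brackets would give a 2-D tensor of shape (2, 3).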

The tensor operations in PyTorch discussed here are:

  • expand(): expands size-1 dimensions of a tensor to a larger size, for example repeating a single row across several rows
  • permute(): reorders the dimensions of the tensor, for example swapping rows and columns
  • tolist(): returns a list or nested list from the given tensor
  • narrow()
  • where(): returns a new tensor whose elements are chosen conditionally from two values
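The sketch below applies each listed operation to a small tensor so their effects can be compared side by side. The tensor values and variable names are illustrative; the calls are the standard Tensor methods expand(), permute(), tolist() and narrow(), plus torch.where(). For where(), torch.full_like() supplies the replacement value as a tensor so the example does not depend on scalar support.

```python
import torch

t = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# expand(): repeat a size-1 dimension without copying data
row = torch.tensor([[7, 8, 9]])          # shape (1, 3)
expanded = row.expand(3, 3)              # shape (3, 3), every row is [7, 8, 9]

# permute(): reorder dimensions; for a 2-D tensor, (1, 0) swaps rows and columns
permuted = t.permute(1, 0)               # shape (3, 2)

# tolist(): convert the tensor to a nested Python list
as_list = t.tolist()

# narrow(dim, start, length): view of `length` entries along `dim`, beginning at `start`
narrowed = t.narrow(1, 0, 2)             # the first two columns

# where(): keep even values, replace odd ones with -1
selected = torch.where(t % 2 == 0, t, torch.full_like(t, -1))

for name, value in [("expand", expanded), ("permute", permuted),
                    ("tolist", as_list), ("narrow", narrowed),
                    ("where", selected)]:
    print(name, value)
```

expand() and narrow() return views that share memory with the original tensor, while tolist() produces an independent Python object.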