where()
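This function returns a new tensor by checking a condition on an existing tensor element by element: each element of the result is taken from one of two values, depending on whether the condition holds at that position.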
Syntax: torch.where(condition, input, other)
where,
- condition is a boolean tensor, usually built by applying a relational operator to an existing tensor
- input is the value (or tensor of values) used where condition is True
- other is the value (or tensor of values) used where condition is False
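The two values passed after the condition need not be scalars. A minimal sketch, assuming two tensors `a` and `b` with the same shape as the condition (both names are illustrative): `torch.where` takes each element from `a` where the condition is true and from `b` otherwise, which matches an explicit element-by-element selection.

```python
import torch

# Two tensors of the same shape to choose from
a = torch.tensor([1, -2, 3, -4])
b = torch.zeros(4, dtype=torch.int64)

# Take the element of a where it is positive, otherwise the element of b
result = torch.where(a > 0, a, b)
print(result)  # tensor([1, 0, 3, 0])

# The same selection written element by element with a list comprehension
manual = torch.tensor([x if x > 0 else y for x, y in zip(a.tolist(), b.tolist())])
print(torch.equal(result, manual))  # True
```

Using a tensor for `b` instead of a scalar lets every position fall back to a different value when the condition is false.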
Example: We will use different relational operators (>, <, ==) to build the condition and check the functionality.
```python
# import the module
import torch

# create a 3-D tensor: one block holding
# 3 rows of 3 elements each
data = torch.tensor([[[10, 20, 30],
                      [45, 67, 89],
                      [23, 45, 67]]])

# display
print(data)

# set the number 100 where the
# number is greater than 45, otherwise 50
print(torch.where(data > 45, 100, 50))

# set the number 100 where the
# number is less than 45, otherwise 50
print(torch.where(data < 45, 100, 50))

# set the number 100 where the number is
# equal to 23, otherwise 50
print(torch.where(data == 23, 100, 50))
```
Output:
```
tensor([[[10, 20, 30],
         [45, 67, 89],
         [23, 45, 67]]])
tensor([[[ 50,  50,  50],
         [ 50, 100, 100],
         [ 50,  50, 100]]])
tensor([[[100, 100, 100],
         [ 50,  50,  50],
         [100,  50,  50]]])
tensor([[[ 50,  50,  50],
         [ 50,  50,  50],
         [100,  50,  50]]])
```
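Relational comparisons can also be combined with the element-wise logical operators `&` (and) and `|` (or) to build a single condition. A short sketch reusing the same `data` values; each comparison is wrapped in parentheses because `&` and `|` bind more tightly than `>` and `<` in Python.

```python
import torch

data = torch.tensor([[[10, 20, 30],
                      [45, 67, 89],
                      [23, 45, 67]]])

# 100 where the value lies strictly between 20 and 60, otherwise 50
between = torch.where((data > 20) & (data < 60), 100, 50)
print(between)

# 100 where the value is below 20 or above 80, otherwise 50
outside = torch.where((data < 20) | (data > 80), 100, 50)
print(outside)
```

Without the parentheses, `data > 20 & data < 60` would be parsed as `data > (20 & data) < 60`, which is not the intended range check.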
Tensor Operations in PyTorch
In this article, we will discuss tensor operations in PyTorch.
PyTorch is a scientific computing package for Python that performs operations on data stored as tensors. A tensor is a multidimensional collection of elements, similar to a NumPy array. We can create a tensor using the tensor() function:
Syntax: torch.tensor([[[element1, element2, ..., elementN], ..., [element1, element2, ..., elementN]]])
where,
- torch is the module
- tensor is the function
- elements are the data
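Each extra level of square brackets adds one dimension to the tensor. A short sketch (the values are illustrative) that builds a nested tensor like the one used in the where() example and inspects its number of dimensions, shape, and element type:

```python
import torch

# Three levels of brackets give a 3-D tensor:
# 1 block containing 3 rows of 3 elements each
data = torch.tensor([[[10, 20, 30],
                      [45, 67, 89],
                      [23, 45, 67]]])

print(data.ndim)   # 3
print(data.shape)  # torch.Size([1, 3, 3])
print(data.dtype)  # torch.int64, the default for Python integers

# Index one bracket level at a time: block 0, row 1, element 2
print(data[0, 1, 2].item())  # 89
```

Because the elements are Python integers, torch.tensor() infers a 64-bit integer type; passing floats such as 10.0 would produce a floating-point tensor instead.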
The operations in PyTorch that can be applied to tensors are: