Caching values on a module level and unit testing

Below is a module for querying and caching AWS STS tokens; the intention is to avoid querying STS when there is still a valid cached token.
import os
import uuid
import logging
from datetime import datetime, timedelta

import boto3

LOGGER = logging.getLogger(__name__)


class Credentials:
    def __init__(self):
        self.sts_credentials = None
        self.token_expiry_time = None

    def is_token_expired(self):
        # Treat tokens expiring within the next two minutes as already expired.
        current_time_with_buffer = datetime.now() + timedelta(minutes=2)
        return not self.token_expiry_time or self.token_expiry_time < current_time_with_buffer


CREDENTIALS_ = Credentials()


def get_credentials():
    if CREDENTIALS_.is_token_expired():
        sts_client = boto3.client('sts')
        LOGGER.info("The credentials are either empty or expiring, refreshing")
        try:
            sts_token = sts_client.assume_role(
                RoleArn=os.environ["KINESIS_ASSUME_ROLE"],
                RoleSessionName=str(uuid.uuid4()))
        except Exception:
            LOGGER.exception(f"Error occurred while trying to assume role with {os.environ['KINESIS_ASSUME_ROLE']}")
            raise
        CREDENTIALS_.sts_credentials = {
            "aws_access_key_id": sts_token['Credentials']['AccessKeyId'],
            "aws_secret_access_key": sts_token['Credentials']['SecretAccessKey'],
            "aws_session_token": sts_token['Credentials']['SessionToken']
        }
        CREDENTIALS_.token_expiry_time = sts_token["Credentials"]["Expiration"]
    return CREDENTIALS_.sts_credentials
One of the unit tests is below. It passes in isolation, but fails when run alongside other tests because the module-level CREDENTIALS_ variable is modified by those tests. I could set it to None between tests, but I want to know the cleaner way of clearing the cached value.
from unittest import mock
from unittest.mock import ANY, call
# plus get_credentials imported from the module under test

def test_get_credentials_refreshes_token_if_about_to_expire(sts_response, credentials):
    with mock.patch("boto3.client") as mock_boto_client:
        mock_assume_role = mock_boto_client.return_value.assume_role
        mock_assume_role.return_value = sts_response
        get_credentials()
        actual_credentials = get_credentials()
        calls = [call('sts'),
                 call().assume_role(RoleArn='arn:aws:iam::000000000000:role/dummyarn', RoleSessionName=ANY),
                 call('sts'),
                 call().assume_role(RoleArn='arn:aws:iam::000000000000:role/dummyarn', RoleSessionName=ANY)]
        assert credentials == actual_credentials
        mock_boto_client.assert_has_calls(calls)

The cleaner way is to make sure that your unit tests actually exercise a single unit: for every unit under test there should be no interaction with other units. Since you are using a global variable CREDENTIALS_, this is going to be nearly impossible.
1) Easy fix
An easy fix would be to pass CREDENTIALS_ as an input argument. Then you can create a fake CREDENTIALS_ object in each test, tailored to your test conditions.
2) Better fix
A better solution would be, besides using the credentials input argument, to break up the logic inside get_credentials. By splitting it into smaller functions, you separate the server logic from the credential updating, making each piece easier to mock and test. A possible division of the whole function would be:
get_sts_token
update_credentials
get_credentials
Now get_sts_token talks to the server, but update_credentials and get_credentials do not have to interact with it directly.
Code
Example 1)
def update_credentials(credentials):
    if credentials.is_token_expired():
        sts_client = boto3.client('sts')
        LOGGER.info("The credentials are either empty or expiring, refreshing")
        try:
            sts_token = sts_client.assume_role(
                RoleArn=os.environ["KINESIS_ASSUME_ROLE"],
                RoleSessionName=str(uuid.uuid4()))
        except Exception:
            LOGGER.exception(f"Error occurred while trying to assume role with {os.environ['KINESIS_ASSUME_ROLE']}")
            raise
        credentials.sts_credentials = {
            "aws_access_key_id": sts_token['Credentials']['AccessKeyId'],
            "aws_secret_access_key": sts_token['Credentials']['SecretAccessKey'],
            "aws_session_token": sts_token['Credentials']['SessionToken']
        }
        credentials.token_expiry_time = sts_token["Credentials"]["Expiration"]
    return credentials

# Where you need the credentials:
CREDENTIALS_ = update_credentials(CREDENTIALS_)
CREDENTIALS_.sts_credentials
Now you can insert your own CREDENTIALS_ object in the test.
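For instance, a test along these lines no longer touches module state (a sketch; it assumes Credentials and update_credentials are importable from the module under test, and reuses the datetime and mock imports from above):
def test_update_credentials_keeps_valid_token():
    # A fake object whose token is still valid; no STS call is expected.
    fake = Credentials()
    fake.sts_credentials = {"aws_access_key_id": "dummy"}
    fake.token_expiry_time = datetime.now() + timedelta(hours=1)
    with mock.patch("boto3.client") as mock_boto_client:
        result = update_credentials(fake)
    mock_boto_client.assert_not_called()
    assert result.sts_credentials == {"aws_access_key_id": "dummy"}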
Example 2)
def get_sts_token():
    sts_client = boto3.client('sts')
    LOGGER.info("The credentials are either empty or expiring, refreshing")
    try:
        sts_token = sts_client.assume_role(
            RoleArn=os.environ["KINESIS_ASSUME_ROLE"],
            RoleSessionName=str(uuid.uuid4()))
    except Exception:
        LOGGER.exception(f"Error occurred while trying to assume role with {os.environ['KINESIS_ASSUME_ROLE']}")
        raise
    return sts_token

def update_credentials(credentials, sts_token):
    credentials.sts_credentials = {
        "aws_access_key_id": sts_token['Credentials']['AccessKeyId'],
        "aws_secret_access_key": sts_token['Credentials']['SecretAccessKey'],
        "aws_session_token": sts_token['Credentials']['SessionToken']
    }
    # Keep the expiry time up to date so is_token_expired keeps working.
    credentials.token_expiry_time = sts_token["Credentials"]["Expiration"]
    return credentials

def get_credentials(credentials: Credentials):
    if credentials.is_token_expired():
        sts_token = get_sts_token()
        credentials = update_credentials(credentials, sts_token)
    return credentials.sts_credentials
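With this split, update_credentials can be tested without mocking boto3 at all. A sketch, reusing the sts_response fixture from the question:
def test_update_credentials_stores_token(sts_response):
    creds = update_credentials(Credentials(), sts_response)
    assert creds.sts_credentials["aws_access_key_id"] == sts_response["Credentials"]["AccessKeyId"]
    assert creds.token_expiry_time == sts_response["Credentials"]["Expiration"]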

Related

Is it possible to use SQLite in EFS reliably?

Is it possible to use SQLite in AWS EFS safely? From my reading on whether this is viable, there are allusions that it should be doable since AWS EFS implemented NFSv4 back in 2017. In practice I am having no luck getting consistent behavior out of it.
Quick Points:
"Just use AWS RDS": Due to issues with other AWS architecture another team has implemented we are trying to work around resource starving cause by the API (DynamoDB isn't an option)
"This goes against SQLite's primary use case (being a locally access DB): Yes, but given the circumstances it seems like the best approach.
I have verified that we are running nfsv4 on our EC2 instance
Current results are very inconsistent, with three exceptions encountered irrespective of the approach I use:
"file is encrypted or is not a database"
"disk I/O error (potentially related to EFS open file limits)"
"database disk image is malformed" (The database actually isn't corrupted after this)
Database code:
import os
import time
import fcntl
from peewee import Model, CharField, chunked
from playhouse.sqlite_ext import SqliteExtDatabase

SQLITE_VAR_LIMIT = 999
dgm_db_file_name = ''
db = SqliteExtDatabase(None)
lock_file = f'{os.getenv("efs_path", "tmp")}/db_lock_file.lock'

def lock_db_file():
    with open(lock_file, 'w+') as lock:
        limit = 900
        while limit:
            try:
                fcntl.flock(lock, fcntl.LOCK_EX | fcntl.LOCK_NB)
                print('db locked')
                break
            except Exception as e:
                print(f'Exception: {str(e)}')
                limit -= 1
                time.sleep(1)
        if not limit:
            raise ValueError('Timed out after 900 seconds while waiting for database lock.')

def unlock_db_file():
    with open(lock_file, 'w+') as lock:
        fcntl.flock(lock, fcntl.LOCK_UN)
        print('db unlocked')

def initialize_db(db_file_path=dgm_db_file_name):
    print('Initializing db')
    global db
    db.init(db_file_path, pragmas={
        'journal_mode': 'wal',
        'cache_size': -1 * 64000,  # 64MB
        'foreign_keys': 1})
    print('db initialized')

class Thing(Model):
    name = CharField(primary_key=True)
    etag = CharField()
    last_modified = CharField()

    class Meta:
        database = db

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def insert_many(stuff):
        data = [(k, v['ETag'], v['Last-Modified']) for k, v in stuff.items()]
        fields = [Thing.name, Thing.etag, Thing.last_modified]
        limit = 900
        while True:
            try:
                with db.atomic():
                    for key_batch in chunked(data, SQLITE_VAR_LIMIT // len(fields)):
                        s = Thing.insert_many(key_batch, fields=[Thing.name, Thing.etag, Thing.last_modified]) \
                            .on_conflict_replace().execute()
                break
            except Exception as e:
                print(f'Exception: {str(e)}')
                print(f'Will try for {limit} more seconds.')
                limit -= 1
                time.sleep(1)
            if not limit:
                raise ValueError('Failed to execute query after 900 seconds.')
Example Call:
print('Critical section start')
# lock_db_file() # I have tried with a secondary lock file as well
self.stuff_db = Thing()
if not Path(self.db_file_path).exists():
    initialize_db(self.db_file_path)
    print('creating tables')
    db.create_tables([Thing], safe=True)
else:
    initialize_db(self.db_file_path)
getattr(Thing, 'insert_many')(self.stuff_db, stuff_db)
# db.close()
# unlock_db_file()
print('Critical section end')
print(f'len after update: {len(stuff)}')
Additional peculiarities:
If a lambda gets stuck catching the "malformed image" exception and a new lambda execution is triggered, the error resolves in the other lambda.
After some trial and error I discovered a workable solution. It appears the design needs to use APSWDatabase(..., vfs='unix-excl') to properly enforce locking.
Database code:
from peewee import *
from playhouse.apsw_ext import APSWDatabase

SQLITE_VAR_LIMIT = 999
db = APSWDatabase(None, vfs='unix-excl')

def initialize_db(db_file_path):
    global db
    db.init(db_file_path, pragmas={
        'journal_mode': 'wal',
        'cache_size': -1 * 64000})
    db.create_tables([Thing], safe=True)
    return Thing()

class Thing(Model):
    field_1 = CharField(primary_key=True)
    field_2 = CharField()
    field_3 = CharField()

    class Meta:
        database = db
This allows for the following usage:
db_model = initialize_db(db_file_path)
with db:
    # Do database queries here with the db_model
    pass
Note: If you don't use the context-managed database connection you will need to explicitly call db.close(), otherwise the lock will not be released from the file. Additionally, calling initialize_db(...) places a lock on the database until it is closed.
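For example, the equivalent non-context-managed usage would look something like this (a sketch of the same pattern described above):
db_model = initialize_db(db_file_path)
try:
    # Do database queries here with the db_model
    pass
finally:
    db.close()  # without this, the lock on the database file is never released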

How to use google-cloud-os-config classes in python code?

In a Google Cloud Function (Python 3.7), I need to fetch the compliance state of all VMs in a given location in a project.
From available google documentation here I could see the REST API format:
https://cloud.google.com/compute/docs/os-configuration-management/view-compliance#view_compliance_state
On searching for the client library here, I found this:
class google.cloud.osconfig_v1alpha.types.ListInstanceOSPoliciesCompliancesRequest(mapping=None, *, ignore_unknown_fields=False, **kwargs)
Bases: proto.message.Message
A request message for listing OS policies compliance data for all Compute Engine VMs in the given location.
parent (str)
    Required. The parent resource name. Format: projects/{project}/locations/{location}. For {project}, either the Compute Engine project number or the project ID can be provided.
page_size (int)
    The maximum number of results to return.
page_token (str)
    A pagination token returned from a previous call to ListInstanceOSPoliciesCompliances that indicates where this listing should continue from.
filter (str)
    If provided, this field specifies the criteria that must be met by an InstanceOSPoliciesCompliance API resource to be included in the response.
And the response class as:
class google.cloud.osconfig_v1alpha.types.ListInstanceOSPoliciesCompliancesResponse(mapping=None, *, ignore_unknown_fields=False, **kwargs)
Bases: proto.message.Message
A response message for listing OS policies compliance data for all Compute Engine VMs in the given location.
instance_os_policies_compliances (Sequence[google.cloud.osconfig_v1alpha.types.InstanceOSPoliciesCompliance])
    List of instance OS policies compliance objects.
next_page_token (str)
    The pagination token to retrieve the next page of instance OS policies compliance objects.
property raw_page
But I am not sure how to use this information in the python code.
I have written this, but I am not sure if it is correct:
from google.cloud.osconfig_v1alpha.services.os_config_zonal_service import client
from google.cloud.osconfig_v1alpha.types import ListInstanceOSPoliciesCompliancesRequest
import logging
import os

logger = logging.getLogger(__name__)

def handler():
    try:
        project_id = os.environ["PROJECT_ID"]
        location = os.environ["ZONE"]
        # list compliance state
        request = ListInstanceOSPoliciesCompliancesRequest(
            parent=f"projects/{project_id}/locations/{location}")
        response = client.instance_os_policies_compliance(request)
        return response
    except Exception as e:
        logger.error("Unable to get compliance - %s" % str(e))
I could not find any usage example for the client library methods anywhere.
Could someone please help me here?
Something like this should work:
from google.cloud import os_config_v1alpha as osc

def handler():
    client = osc.OsConfigZonalService()
    project_id = "my_project"
    location = "my_gcp_zone"
    parent = f"projects/{project_id}/locations/{location}"
    response = client.list_instance_os_policies_compliances(
        parent=parent
    )
    # response is an iterable yielding
    # InstanceOSPoliciesCompliance objects
    for result in response:
        # do something with result
        ...
You can also construct the request like this:
response = client.list_instance_os_policies_compliances(
    request={
        "parent": parent
    }
)
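The optional fields documented above can be passed the same way; for example (the page size and filter value here are purely hypothetical illustrations, not values from the docs):
response = client.list_instance_os_policies_compliances(
    request={
        "parent": parent,
        "page_size": 50,  # hypothetical page size
        "filter": "<your filter criteria>",  # hypothetical filter
    }
)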
Answering my own question here; this is what I used:
from googleapiclient.discovery import build

def list_policy_compliance():
    projectId = "my_project"
    zone = "my_zone"
    try:
        service = build('osconfig', 'v1alpha', cache_discovery=False)
        compliance_response = service.projects().locations(
        ).instanceOsPoliciesCompliances().list(
            parent='projects/%s/locations/%s' % (
                projectId, zone)).execute()
        return compliance_response
    except Exception:
        # Re-raise so the caller sees the original error.
        raise

Execute customized script when launching instance using openstacksdk

I'm new to Openstack and I'm trying to create a tool so that I can launch any number of instances in an Openstack cloud. This was easily done using the nova-client module of openstacksdk.
Now the problem is that I want the instances to execute a bash script as they are created, by adding it as a userdata file, but the script doesn't execute. This is confusing because I don't get any error or warning message. Does anyone know what it could be?
Important parts of the code
The most important parts of the Python program are the function which gets the cloud info, the one that creates the instances, and the main function. I'll post them here as @Corey asked.
"""
Function that allow us to log at cloud with all the credentials needed.
Username and password are not read from env.
"""
def get_nova_credentials_v2():
d = {}
user = ""
password = ""
print("Logging in...")
user = input("Username: ")
password = getpass.getpass(prompt="Password: ", stream=None)
while (user == "" or password == ""):
print("User or password field is empty")
user = input("Username: ")
password = getpass.getpass(prompt="Password: ", stream=None)
d['version'] = '2.65'
d['username'] = user
d['password'] = password
d['project_id'] = os.environ['OS_PROJECT_ID']
d['auth_url'] = os.environ['OS_AUTH_URL']
d['user_domain_name'] = os.environ['OS_USER_DOMAIN_NAME']
return d
Then we have the create_server function:
"""
This function creates a server using the info we got from JSON file
"""
def create_server(server):
s = {}
print("Creating "+server['compulsory']['name']+"...")
s['name'] = server['compulsory']['name']
s['image'] = server['compulsory']['os']
s['flavor'] = server['compulsory']['flavor']
s['min_count'] = server['compulsory']['copyNumber']
s['max_count'] = server['compulsory']['copyNumber']
s['userdata'] = server['file']
s['key_name'] = server['compulsory']['keyName']
s['availability_zone'] = server['compulsory']['availabilityZone']
s['nics'] = server['compulsory']['network']
print(s['userdata'])
if(exists("instalacion_k8s_docker.sh")):
print("Exists")
s['userdata'] = server['file']
nova.servers.create(**s)
And now the main function:
"""
Main process: First we create a connection to Openstack using our credentials.
Once connected we cal get_serverdata function to get all instance objects we want to be created.
We check that it is not empty and that we are not trying to create more instances than we are allowed.
Lastly we create the instances and the program finishes.
"""
credentials = get_nova_credentials_v2()
nova = client.Client(**credentials)
instances = get_serverdata()
current_instances = len(nova.servers.list())
if not instances:
print("No instance was writen. Check instances.json file.")
exit(3)
num = 0
for i in instances:
create_server(i)
exit(0)
For the rest of the code you can access this public repo on GitHub.
Thanks a lot!
Problem solved
The problem was the content of server['file'], as @Corey said. It cannot be the path to the file where you wrote the data; it must be the content of that file or a file-type object. In the case of OpenstackSDK it must be base64-encoded, but that is not the case in Novaclient.
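A minimal sketch of the fix (assuming the script name used in the question):
# novaclient: pass the script's content, not its path
with open("instalacion_k8s_docker.sh") as script:
    s['userdata'] = script.read()
nova.servers.create(**s)

# With openstacksdk the same content would additionally need to be
# base64-encoded before being passed as userdata.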
Thanks a lot to @Corey for all the help! :)

How to set mocked exception behavior on Python?

I am using an external library (github3.py) that defines an internal exception (github3.exceptions.UnprocessableEntity). How this exception is defined shouldn't matter to my test, so I want to create a side effect and set the attributes I use from this exception.
Code under test (a not-so-minimal example):
import github3

class GithubService:
    def __init__(self, token: str) -> None:
        self.connection = github3.login(token=token)
        self.repos = self.connection.repositories()

    def create_pull(self, repo_name: str) -> str:
        for repo in self.repos:
            if repo.full_name == repo_name:
                break
        try:
            created_pr = repo.create_pull(
                title="title",
                body="body",
                head="head",
                base="base",
            )
        except github3.exceptions.UnprocessableEntity as github_exception:
            extra = ""
            for error in github_exception.errors:
                if "message" in error:
                    extra += f"{error['message']} "
                else:
                    extra += f"Invalid field {error['field']}. "  # testing this case
            return f"{repo_name}: {github_exception.msg}. {extra}"
I need to set the attributes msg and errors on the exception. So I tried this in my test code, using pytest-mock:
@pytest.fixture
def mock_github3_login(mocker: MockerFixture) -> MockerFixture:
    """Fixture for mocking github3.login."""
    mock = mocker.patch("github3.login", autospec=True)
    mock.return_value.repositories.return_value = [
        mocker.Mock(full_name="staticdev/nope"),
        mocker.Mock(full_name="staticdev/omg"),
    ]
    return mock

def test_create_pull_invalid_field(
    mocker: MockerFixture, mock_github3_login: MockerFixture,
) -> None:
    exception_mock = mocker.Mock(errors=[{"field": "head"}], msg="Validation Failed")
    mock_github3_login.return_value.repositories.return_value[1].create_pull.side_effect = github3.exceptions.UnprocessableEntity(mocker.Mock())
    mock_github3_login.return_value.repositories.return_value[1].create_pull.return_value = exception_mock
    response = GithubService("faketoken").create_pull("staticdev/omg")
    assert response == "staticdev/omg: Validation Failed. Invalid field head."
The problem with this code is that when a mock has both side_effect and return_value, the return_value is simply ignored.
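This is easy to demonstrate with a plain Mock (a standalone illustration, unrelated to github3):
from unittest.mock import Mock

m = Mock(side_effect=ValueError("boom"), return_value=42)
m()  # raises ValueError("boom"); the return_value of 42 is never used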
The problem here is that I don't want to know the implementation of UnprocessableEntity in order to call its constructor with the right arguments. And I didn't find another way using just side_effect. I also tried using return_value and setting the class of the mock, like this:
def test_create_pull_invalid_field(
    mock_github3_login: MockerFixture,
) -> None:
    exception_mock = Mock(__class__=github3.exceptions.UnprocessableEntity, errors=[{"field": "head"}], msg="Validation Failed")
    mock_github3_login.return_value.repositories.return_value[1].create_pull.return_value = exception_mock
    response = GithubService("faketoken").create_pull("staticdev/omg")
    assert response == "staticdev/omg: Validation Failed. Invalid field head."
This also does not work; the exception is not thrown. So I don't know how to overcome this issue, given the constraint that I don't want to depend on the implementation of UnprocessableEntity. Any ideas?
Based on your example, you don't really need to mock github3.exceptions.UnprocessableEntity itself, but only the incoming resp argument it is constructed from.
So the following test should work:
def test_create_pull_invalid_field(
    mocker: MockerFixture, mock_github3_login: MockerFixture,
) -> None:
    mocked_response = mocker.Mock()
    mocked_response.json.return_value = {
        "message": "Validation Failed", "errors": [{"field": "head"}]
    }
    repo = mock_github3_login.return_value.repositories.return_value[1]
    repo.create_pull.side_effect = github3.exceptions.UnprocessableEntity(mocked_response)
    response = GithubService("faketoken").create_pull("staticdev/omg")
    assert response == "staticdev/omg: Validation Failed. Invalid field head."
EDIT:
If you want github3.exceptions.UnprocessableEntity to be completely abstracted, it won't be possible to mock the entire class, since catching classes that do not inherit from BaseException is not allowed (see the docs). But you can get around that by mocking only the constructor:
def test_create_pull_invalid_field(
    mocker: MockerFixture, mock_github3_login: MockerFixture,
) -> None:
    def _initiate_mocked_exception(self) -> None:
        self.errors = [{"field": "head"}]
        self.msg = "Validation Failed"

    mocker.patch.object(
        github3.exceptions.UnprocessableEntity, "__init__",
        _initiate_mocked_exception
    )
    repo = mock_github3_login.return_value.repositories.return_value[1]
    repo.create_pull.side_effect = github3.exceptions.UnprocessableEntity
    response = GithubService("faketoken").create_pull("staticdev/omg")
    assert response == "staticdev/omg: Validation Failed. Invalid field head."

How to reference/return a value from SignalR?

This is the code I have:
from signalr_aio import Connection

if __name__ == "__main__":
    # Create connection
    # Users can optionally pass a session object to the client, e.g. a cfscrape session to bypass Cloudflare.
    connection = Connection('https://beta.bittrex.com/signalr', session=None)
    hub = connection.register_hub('c2')
    hub.server.invoke('GetAuthContext', API_KEY)  # Invoke 0: creates the challenge that needs to be signed by the create_signature coroutine
    signature = await create_signature(API_SECRET, challenge)  # Creates the signature that needs to be authenticated in the Authenticate query
    hub.server.invoke('Authenticate', API_KEY, signature)  # Invoke 1: authenticates the user to account-level information
    connection.start()
What I have to do is verify my identity by getting a string challenge from the GetAuthContext call, then create a string signature using that challenge, and then pass that signature to the Authenticate call. The problem I'm having is that I need to feed the return value of GetAuthContext into the challenge parameter of the create_signature coroutine. I'm guessing from the comment in the example below that every invoke call gets marked as I([index of call]), so I would have to do signature = await create_signature(API_SECRET, 'I(0)').
async def on_debug(**msg):
    # In case of `queryExchangeState` or `GetAuthContext`
    if 'R' in msg and type(msg['R']) is not bool:
        # For the simplicity of the example I(1) corresponds to `Authenticate` and I(0) to `GetAuthContext`
        # Check the main body for more info.
        if msg['I'] == str(2):
            decoded_msg = await process_message(msg['R'])
            print(decoded_msg)
        elif msg['I'] == str(3):
            signature = await create_signature(API_SECRET, msg['R'])
            hub.server.invoke('Authenticate', API_KEY, signature)
Later this example gets assigned to connection.received (connection.received += on_debug), so I'm guessing that after connection.start() I have to put connection.received() to call the on_debug coroutine, which will verify me. But for now I just want to understand how to reference the .invoke() results for use within a function or coroutine.
I am far from an expert, but the feed from Bittrex is indeed a dictionary:
for i in range(0, len(decoded_msg['D'])):
    print('The currency pair is: {0}, the bid is: {1} and the ask is: {2}'.format(
        decoded_msg['D'][i]['M'], decoded_msg['D'][i]['B'], decoded_msg['D'][i]['A']))
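For what it's worth, a sketch of the wiring implied in the question (all names are taken from the question's own code; untested against the live Bittrex endpoint):
connection = Connection('https://beta.bittrex.com/signalr', session=None)
hub = connection.register_hub('c2')
connection.received += on_debug  # register the handler before starting
hub.server.invoke('GetAuthContext', API_KEY)  # the result arrives in on_debug as msg['R'], with msg['I'] identifying the invoke
connection.start()  # blocking; on_debug chains the Authenticate call once the challenge arrives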
